rustauth_oauth/oauth2/
http.rs1use reqwest::{Client, Response};
2use serde_json::Value;
3use std::sync::OnceLock;
4use std::time::Duration;
5
6use super::error::{oauth_error_description, OAuthError};
7use super::request::OAuthFormRequest;
8use super::ssrf::{ssrf_guarded_client_builder, url_host_is_blocked_ip};
9
/// Request timeout used by `OAuthHttpClientConfig::default()` (10 seconds).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
/// User agent used by `OAuthHttpClientConfig::default()`: the literal `rustauth-oauth/`
/// followed by the crate version read from `CARGO_PKG_VERSION` at compile time.
const DEFAULT_USER_AGENT: &str = concat!("rustauth-oauth/", env!("CARGO_PKG_VERSION"));
/// Field names treated as secrets when redacting response bodies.
///
/// Entries are lowercase: `redact_body` looks for them as substrings of an ASCII-lowercased
/// copy of the body, while `redact_json_value` compares them with JSON object keys ignoring
/// ASCII case.
const SENSITIVE_OAUTH_FIELDS: &[&str] = &[
    "access_token",
    "refresh_token",
    "id_token",
    "client_secret",
    "client_assertion",
    "subject_token",
    "device_code",
    "code",
    "token",
    "authorization",
];
24
/// HTTP client wrapper used for OAuth requests.
///
/// Pairs a `reqwest` [`Client`] with a flag that decides whether request URLs are screened
/// with `url_host_is_blocked_ip` before a request is sent.
#[derive(Debug, Clone)]
pub struct OAuthHttpClient {
    /// Underlying `reqwest` client that performs every request.
    client: Client,
    /// When `false`, a URL flagged by `url_host_is_blocked_ip` is rejected with
    /// `OAuthError::BlockedRequestUrl`; when `true`, that check is skipped.
    allow_private_ips: bool,
}
33
/// Settings consumed by `OAuthHttpClient::from_config`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthHttpClientConfig {
    /// Timeout handed to the `reqwest` client builder. A zero duration is rejected by
    /// `OAuthHttpClient::from_config` with `OAuthError::InvalidConfiguration`.
    pub timeout: Duration,
    /// User agent handed to the client builder; `None` leaves the builder's user agent untouched.
    pub user_agent: Option<String>,
    /// When `true`, the client is built from a plain `Client::builder()` and request URLs are
    /// not screened; when `false`, `ssrf_guarded_client_builder()` is used and URLs flagged by
    /// `url_host_is_blocked_ip` are rejected.
    pub allow_private_ips: bool,
}
44
45impl Default for OAuthHttpClientConfig {
46 fn default() -> Self {
47 Self {
48 timeout: DEFAULT_TIMEOUT,
49 user_agent: Some(DEFAULT_USER_AGENT.to_owned()),
50 allow_private_ips: false,
51 }
52 }
53}
54
55impl OAuthHttpClient {
56 pub fn new(client: Client) -> Self {
64 Self {
65 client,
66 allow_private_ips: true,
67 }
68 }
69
70 pub fn reqwest_client(&self) -> &Client {
76 &self.client
77 }
78
79 pub fn default_client() -> Result<Self, OAuthError> {
80 Self::from_config(OAuthHttpClientConfig::default())
81 }
82
83 pub fn from_config(config: OAuthHttpClientConfig) -> Result<Self, OAuthError> {
84 if config.timeout.is_zero() {
85 return Err(OAuthError::InvalidConfiguration(
86 "HTTP timeout must be greater than zero".to_owned(),
87 ));
88 }
89 let mut builder = if config.allow_private_ips {
90 Client::builder()
91 } else {
92 ssrf_guarded_client_builder()
93 }
94 .timeout(config.timeout);
95 if let Some(user_agent) = config.user_agent {
96 builder = builder.user_agent(user_agent);
97 }
98 builder
99 .build()
100 .map(|client| Self {
101 client,
102 allow_private_ips: config.allow_private_ips,
103 })
104 .map_err(Into::into)
105 }
106
107 fn ensure_request_url_allowed(&self, url: &str) -> Result<(), OAuthError> {
111 if !self.allow_private_ips && url_host_is_blocked_ip(url) {
112 return Err(OAuthError::BlockedRequestUrl);
113 }
114 Ok(())
115 }
116
117 pub async fn get_bytes(&self, url: &str) -> Result<Vec<u8>, OAuthError> {
118 self.get_bytes_with_headers(url, &[]).await
119 }
120
121 pub async fn get_bytes_with_headers(
122 &self,
123 url: &str,
124 headers: &[(&str, &str)],
125 ) -> Result<Vec<u8>, OAuthError> {
126 self.ensure_request_url_allowed(url)?;
127 let mut builder = self.client.get(url).header("accept", "application/json");
128 for (key, value) in headers {
129 builder = builder.header(*key, *value);
130 }
131 let response = builder.send().await?;
132 response_bytes(response).await
133 }
134
135 pub async fn post_form(
136 &self,
137 token_endpoint: &str,
138 request: OAuthFormRequest,
139 ) -> Result<Value, OAuthError> {
140 self.ensure_request_url_allowed(token_endpoint)?;
141 let mut builder = self.client.post(token_endpoint);
142 for (key, value) in &request.headers {
143 builder = builder.header(key, value);
144 }
145 let response = builder.body(request.to_form_urlencoded()).send().await?;
146 response_json(response).await
147 }
148}
149
150pub fn default_http_client() -> Result<OAuthHttpClient, OAuthError> {
151 static CLIENT: OnceLock<Result<OAuthHttpClient, String>> = OnceLock::new();
152
153 CLIENT
154 .get_or_init(|| OAuthHttpClient::default_client().map_err(|error| error.to_string()))
155 .clone()
156 .map_err(OAuthError::InvalidConfiguration)
157}
158
159async fn response_bytes(response: Response) -> Result<Vec<u8>, OAuthError> {
160 let status = response.status();
161 let bytes = response.bytes().await?;
162 if status.is_success() {
163 return Ok(bytes.to_vec());
164 }
165 Err(http_status_error(status.as_u16(), &bytes))
166}
167
168async fn response_json(response: Response) -> Result<Value, OAuthError> {
169 let status = response.status();
170 let bytes = response.bytes().await?;
171 let value = serde_json::from_slice::<Value>(&bytes);
172 if status.is_success() {
173 return value.map_err(|error| OAuthError::InvalidResponse(error.to_string()));
174 }
175 if let Ok(value) = value {
176 if let Some(error) = value.get("error").and_then(Value::as_str) {
177 return Err(OAuthError::ErrorResponse {
178 error: error.to_owned(),
179 description: oauth_error_description(redact_error_description(
180 value.get("error_description").and_then(Value::as_str),
181 )),
182 uri: value
183 .get("error_uri")
184 .and_then(Value::as_str)
185 .map(str::to_owned),
186 });
187 }
188 }
189 Err(http_status_error(status.as_u16(), &bytes))
190}
191
192fn http_status_error(status: u16, body: &[u8]) -> OAuthError {
193 OAuthError::HttpStatus {
194 status,
195 body: redact_body(&String::from_utf8_lossy(body)),
196 }
197}
198
199fn redact_body(body: &str) -> String {
200 if let Ok(mut value) = serde_json::from_str::<Value>(body) {
201 redact_json_value(&mut value);
202 return value.to_string();
203 }
204
205 let lower = body.to_ascii_lowercase();
206 if SENSITIVE_OAUTH_FIELDS.iter().any(|key| lower.contains(key))
207 || lower.contains("bearer ")
208 || lower.contains("basic ")
209 {
210 return "<redacted OAuth response body>".to_owned();
211 }
212 body.to_owned()
213}
214
215fn redact_json_value(value: &mut Value) {
216 match value {
217 Value::Object(object) => {
218 for (key, value) in object {
219 if SENSITIVE_OAUTH_FIELDS
220 .iter()
221 .any(|sensitive| key.eq_ignore_ascii_case(sensitive))
222 {
223 *value = Value::String("<redacted>".to_owned());
224 } else {
225 redact_json_value(value);
226 }
227 }
228 }
229 Value::Array(values) => {
230 for value in values {
231 redact_json_value(value);
232 }
233 }
234 _ => {}
235 }
236}
237
/// Screens an OAuth `error_description` for text that looks like it carries a secret.
///
/// Returns `None` for `None`. Otherwise the description is ASCII-lowercased and searched for
/// a fixed set of markers; a hit yields a placeholder, a miss yields the description unchanged.
fn redact_error_description(description: Option<&str>) -> Option<String> {
    // Narrower than the body-level field list: the bare words "code" and "token" are not here.
    const MARKERS: [&str; 10] = [
        "access_token",
        "refresh_token",
        "id_token",
        "client_secret",
        "client_assertion",
        "subject_token",
        "device_code",
        "authorization",
        "bearer ",
        "basic ",
    ];

    description.map(|text| {
        let folded = text.to_ascii_lowercase();
        let suspicious = MARKERS.iter().any(|&marker| folded.contains(marker));
        if suspicious {
            "<redacted error_description>".to_owned()
        } else {
            text.to_owned()
        }
    })
}