1use std::{collections::BTreeMap, fmt};
2
3use reqwest::header::CONTENT_TYPE;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use url::Url;
7
8use crate::{
9 ClientBuilder, CloudConvertClient, Error, Result,
10 config::{OAuthAccessToken, OAuthClientSecret, OAuthRefreshToken},
11};
12
/// Default authorization endpoint; `OAuthClient` appends the authorization-request
/// query parameters (`response_type`, `client_id`, `redirect_uri`, ...) to this URL.
const OAUTH_AUTHORIZE_URL: &str = "https://cloudconvert.com/oauth/authorize";
/// Default token endpoint; authorization-code exchanges and refresh-token grants are
/// POSTed here as form-encoded bodies.
const OAUTH_TOKEN_URL: &str = "https://cloudconvert.com/oauth/token";
15
/// An OAuth scope that can be requested from CloudConvert.
///
/// Well-known scopes have dedicated variants; any other scope name is carried verbatim in
/// [`OAuthScope::Other`]. Prefer building values with `OAuthScope::from(..)`, which maps known
/// names to their dedicated variant. Equality and hashing are derived structurally, so
/// `OAuthScope::Other("user.read".into())` is *not* equal to [`OAuthScope::UserRead`] even
/// though both render as the same scope string.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum OAuthScope {
    /// The `user.read` scope.
    UserRead,
    /// The `user.write` scope.
    UserWrite,
    /// The `task.read` scope.
    TaskRead,
    /// The `task.write` scope.
    TaskWrite,
    /// The `webhook.read` scope.
    WebhookRead,
    /// The `webhook.write` scope.
    WebhookWrite,
    /// Any scope name without a dedicated variant, stored exactly as given.
    Other(String),
}
27
28impl OAuthScope {
29 pub fn as_str(&self) -> &str {
30 match self {
31 Self::UserRead => "user.read",
32 Self::UserWrite => "user.write",
33 Self::TaskRead => "task.read",
34 Self::TaskWrite => "task.write",
35 Self::WebhookRead => "webhook.read",
36 Self::WebhookWrite => "webhook.write",
37 Self::Other(value) => value.as_str(),
38 }
39 }
40}
41
42impl From<&str> for OAuthScope {
43 fn from(value: &str) -> Self {
44 match value {
45 "user.read" => Self::UserRead,
46 "user.write" => Self::UserWrite,
47 "task.read" => Self::TaskRead,
48 "task.write" => Self::TaskWrite,
49 "webhook.read" => Self::WebhookRead,
50 "webhook.write" => Self::WebhookWrite,
51 _ => Self::Other(value.to_string()),
52 }
53 }
54}
55
56impl From<String> for OAuthScope {
57 fn from(value: String) -> Self {
58 Self::from(value.as_str())
59 }
60}
61
62impl Serialize for OAuthScope {
63 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
64 where
65 S: serde::Serializer,
66 {
67 serializer.serialize_str(self.as_str())
68 }
69}
70
71#[derive(Clone)]
72pub struct OAuthClient {
73 client_id: String,
74 client_secret: OAuthClientSecret,
75 authorize_url: Url,
76 token_url: Url,
77 http: reqwest::Client,
78}
79
80impl OAuthClient {
81 pub fn new(client_id: impl Into<String>, client_secret: OAuthClientSecret) -> Result<Self> {
82 Self::with_http_client(client_id, client_secret, reqwest::Client::new())
83 }
84
85 pub fn with_http_client(
86 client_id: impl Into<String>,
87 client_secret: OAuthClientSecret,
88 http: reqwest::Client,
89 ) -> Result<Self> {
90 Ok(Self {
91 client_id: client_id.into(),
92 client_secret,
93 authorize_url: Url::parse(OAUTH_AUTHORIZE_URL)?,
94 token_url: Url::parse(OAUTH_TOKEN_URL)?,
95 http,
96 })
97 }
98
99 pub fn with_endpoints(mut self, authorize_url: Url, token_url: Url) -> Self {
100 self.authorize_url = authorize_url;
101 self.token_url = token_url;
102 self
103 }
104
105 pub fn client_id(&self) -> &str {
106 self.client_id.as_str()
107 }
108
109 pub fn authorize_url(&self) -> &Url {
110 &self.authorize_url
111 }
112
113 pub fn token_url(&self) -> &Url {
114 &self.token_url
115 }
116
117 pub fn authorization_code_url(
118 &self,
119 redirect_uri: impl AsRef<str>,
120 scopes: impl IntoIterator<Item = OAuthScope>,
121 ) -> Result<Url> {
122 self.authorization_url("code", redirect_uri.as_ref(), scopes, None)
123 }
124
125 pub fn authorization_code_url_with_state(
126 &self,
127 redirect_uri: impl AsRef<str>,
128 scopes: impl IntoIterator<Item = OAuthScope>,
129 state: impl Into<String>,
130 ) -> Result<Url> {
131 self.authorization_url("code", redirect_uri.as_ref(), scopes, Some(state.into()))
132 }
133
134 pub fn implicit_url(
135 &self,
136 redirect_uri: impl AsRef<str>,
137 scopes: impl IntoIterator<Item = OAuthScope>,
138 ) -> Result<Url> {
139 self.authorization_url("token", redirect_uri.as_ref(), scopes, None)
140 }
141
142 pub fn implicit_url_with_state(
143 &self,
144 redirect_uri: impl AsRef<str>,
145 scopes: impl IntoIterator<Item = OAuthScope>,
146 state: impl Into<String>,
147 ) -> Result<Url> {
148 self.authorization_url("token", redirect_uri.as_ref(), scopes, Some(state.into()))
149 }
150
151 pub async fn exchange_code(
152 &self,
153 code: impl Into<String>,
154 redirect_uri: impl Into<String>,
155 ) -> Result<OAuthTokenResponse> {
156 self.send_token_request(vec![
157 ("grant_type", "authorization_code".to_string()),
158 ("code", code.into()),
159 ("redirect_uri", redirect_uri.into()),
160 ("client_id", self.client_id.clone()),
161 ("client_secret", self.client_secret.expose().to_string()),
162 ])
163 .await
164 }
165
166 pub async fn refresh_access_token(
167 &self,
168 refresh_token: &OAuthRefreshToken,
169 ) -> Result<OAuthTokenResponse> {
170 self.send_token_request(vec![
171 ("grant_type", "refresh_token".to_string()),
172 ("refresh_token", refresh_token.expose().to_string()),
173 ("client_id", self.client_id.clone()),
174 ("client_secret", self.client_secret.expose().to_string()),
175 ])
176 .await
177 }
178
179 fn authorization_url(
180 &self,
181 response_type: &str,
182 redirect_uri: &str,
183 scopes: impl IntoIterator<Item = OAuthScope>,
184 state: Option<String>,
185 ) -> Result<Url> {
186 let mut url = self.authorize_url.clone();
187 let scope = scopes
188 .into_iter()
189 .map(|scope| scope.as_str().to_string())
190 .collect::<Vec<_>>()
191 .join(" ");
192
193 {
194 let mut query = url.query_pairs_mut();
195 query
196 .append_pair("response_type", response_type)
197 .append_pair("client_id", &self.client_id)
198 .append_pair("redirect_uri", redirect_uri);
199 if !scope.is_empty() {
200 query.append_pair("scope", &scope);
201 }
202 if let Some(state) = state {
203 query.append_pair("state", &state);
204 }
205 }
206
207 Ok(url)
208 }
209
210 async fn send_token_request(
211 &self,
212 form: Vec<(&'static str, String)>,
213 ) -> Result<OAuthTokenResponse> {
214 let body = form_body(&form);
215 let response = self
216 .http
217 .post(self.token_url.clone())
218 .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
219 .body(body)
220 .send()
221 .await?;
222
223 if !response.status().is_success() {
224 return Err(oauth_api_error(response).await);
225 }
226
227 let raw = response.json::<RawOAuthTokenResponse>().await?;
228 Ok(raw.into())
229 }
230}
231
232impl fmt::Debug for OAuthClient {
233 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
234 formatter
235 .debug_struct("OAuthClient")
236 .field("client_id", &self.client_id)
237 .field("client_secret", &self.client_secret)
238 .field("authorize_url", &self.authorize_url)
239 .field("token_url", &self.token_url)
240 .field("http", &"reqwest::Client")
241 .finish()
242 }
243}
244
245#[derive(Clone)]
246#[non_exhaustive]
247pub struct OAuthTokenResponse {
248 pub access_token: OAuthAccessToken,
249 pub refresh_token: Option<OAuthRefreshToken>,
250 pub token_type: Option<String>,
251 pub expires_in: Option<u64>,
252 pub scope: Option<String>,
253 pub extra: BTreeMap<String, Value>,
254}
255
256impl OAuthTokenResponse {
257 pub fn access_token(&self) -> &OAuthAccessToken {
258 &self.access_token
259 }
260
261 pub fn refresh_token(&self) -> Option<&OAuthRefreshToken> {
262 self.refresh_token.as_ref()
263 }
264
265 pub fn into_access_token(self) -> OAuthAccessToken {
266 self.access_token
267 }
268
269 pub fn client_builder(&self) -> ClientBuilder {
270 CloudConvertClient::builder_with_access_token(self.access_token.clone())
271 }
272
273 pub fn into_client_builder(self) -> ClientBuilder {
274 CloudConvertClient::builder_with_access_token(self.access_token)
275 }
276}
277
278impl fmt::Debug for OAuthTokenResponse {
279 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
280 formatter
281 .debug_struct("OAuthTokenResponse")
282 .field("access_token", &self.access_token)
283 .field("refresh_token", &self.refresh_token)
284 .field("token_type", &self.token_type)
285 .field("expires_in", &self.expires_in)
286 .field("scope", &self.scope)
287 .field("extra", &self.extra)
288 .finish()
289 }
290}
291
292#[derive(Deserialize)]
293struct RawOAuthTokenResponse {
294 access_token: String,
295 #[serde(default)]
296 refresh_token: Option<String>,
297 #[serde(default)]
298 token_type: Option<String>,
299 #[serde(default)]
300 expires_in: Option<u64>,
301 #[serde(default)]
302 scope: Option<String>,
303 #[serde(flatten)]
304 extra: BTreeMap<String, Value>,
305}
306
307impl From<RawOAuthTokenResponse> for OAuthTokenResponse {
308 fn from(value: RawOAuthTokenResponse) -> Self {
309 Self {
310 access_token: OAuthAccessToken::new(value.access_token),
311 refresh_token: value.refresh_token.map(OAuthRefreshToken::new),
312 token_type: value.token_type,
313 expires_in: value.expires_in,
314 scope: value.scope,
315 extra: value.extra,
316 }
317 }
318}
319
320fn form_body(form: &[(&'static str, String)]) -> String {
321 let mut serializer = url::form_urlencoded::Serializer::new(String::new());
322 for (key, value) in form {
323 serializer.append_pair(key, value);
324 }
325 serializer.finish()
326}
327
328async fn oauth_api_error(response: reqwest::Response) -> Error {
329 let status = response.status().as_u16();
330 let body = response.text().await.unwrap_or_default();
331 let parsed = serde_json::from_str::<Value>(&body).ok();
332 let message = parsed
333 .as_ref()
334 .and_then(|body| {
335 body.get("error_description")
336 .or_else(|| body.get("message"))
337 .and_then(Value::as_str)
338 })
339 .filter(|message| !message.is_empty())
340 .unwrap_or("OAuth token request failed")
341 .to_string();
342 let code = parsed
343 .as_ref()
344 .and_then(|body| {
345 body.get("error")
346 .or_else(|| body.get("code"))
347 .and_then(Value::as_str)
348 })
349 .map(ToString::to_string);
350
351 Error::Api {
352 status,
353 message,
354 code,
355 errors: parsed.map(Box::new),
356 rate_limit: None,
357 }
358}