1use std::time::Duration;
2
3use base64::prelude::*;
4use query_string_builder::QueryString;
5use reqwest::{RequestBuilder, StatusCode, header::HeaderMap};
6use serde::{Deserialize, Serialize};
7use sha2::Digest;
8
9pub mod error;
10pub mod x;
11
12pub use reqwest;
13
14use crate::error::OAuth2Error;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct TokenResult {
18 pub access_token: String,
19 pub refresh_token: String,
20 pub expires_in: u64,
21 pub scope: String,
22 pub token_type: String,
23}
24
/// Value of the `response_type` parameter sent on the authorization request.
pub(crate) enum ResponseType {
    /// Authorization-code flow; rendered as `code`.
    Code,
    /// Rendered as `token`. The dead-code lint is silenced because no code in
    /// this file constructs this variant.
    #[allow(unused)]
    Token,
}

impl std::fmt::Display for ResponseType {
    /// Writes the literal wire value used in the authorize URL query string.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let wire_value = match self {
            ResponseType::Code => "code",
            ResponseType::Token => "token",
        };
        f.write_str(wire_value)
    }
}
39
/// Value of the PKCE `code_challenge_method` parameter (RFC 7636).
pub(crate) enum CodeChallengeMethod {
    /// SHA-256 transform; rendered as `S256`.
    S256,
    /// Identity transform; rendered as `plain`. The dead-code lint is silenced
    /// because only the test module in this file constructs it.
    #[allow(unused)]
    Plain,
}

impl std::fmt::Display for CodeChallengeMethod {
    /// Writes the literal wire value used in the authorize URL query string.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let wire_value = match self {
            CodeChallengeMethod::S256 => "S256",
            CodeChallengeMethod::Plain => "plain",
        };
        f.write_str(wire_value)
    }
}
54
55pub(crate) struct PkceS256 {
56 pub code_challenge: String,
57 pub code_verifier: String,
58}
59
60impl PkceS256 {
61 pub fn new() -> Self {
62 let size = 32;
63 let random_bytes: Vec<u8> = (0..size).map(|_| rand::random::<u8>()).collect();
64 let code_verifier = BASE64_URL_SAFE_NO_PAD.encode(&random_bytes);
65 let code_challenge = {
66 let hash = sha2::Sha256::digest(code_verifier.as_bytes());
67 BASE64_URL_SAFE_NO_PAD.encode(hash)
68 };
69 Self {
70 code_challenge,
71 code_verifier,
72 }
73 }
74}
75
76impl Default for PkceS256 {
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
82#[allow(clippy::too_many_arguments)]
83pub(crate) fn authorize_url(
84 url: &str,
85 response_type: ResponseType,
86 client_id: &str,
87 redirect_uri: &str,
88 scopes: &str,
89 state: &str,
90 code_challenge: &str,
91 code_challenge_method: CodeChallengeMethod,
92) -> String {
93 let qs = QueryString::dynamic()
94 .with_value("response_type", response_type.to_string())
95 .with_value("client_id", client_id)
96 .with_value("redirect_uri", redirect_uri)
97 .with_value("scope", scopes)
98 .with_value("state", state)
99 .with_value("code_challenge", code_challenge)
100 .with_value("code_challenge_method", code_challenge_method.to_string());
101 format!("{}{}", url, qs)
102}
103
104#[allow(clippy::too_many_arguments)]
105pub(crate) async fn token(
106 url: &str,
107 client_id: &str,
108 client_secret: &str,
109 redirect_uri: &str,
110 code: &str,
111 code_verifier: &str,
112 grant_type: &str,
113 timeout: Duration,
114 try_count: usize,
115 retry_millis: u64,
116) -> Result<(TokenResult, StatusCode, HeaderMap), OAuth2Error> {
117 let params = [
118 ("grant_type", grant_type),
119 ("code", code),
120 ("redirect_uri", redirect_uri),
121 ("client_id", client_id),
122 ("code_verifier", code_verifier),
123 ];
124
125 let client = reqwest::Client::new();
126
127 execute_retry(
128 || {
129 client
130 .post(url)
131 .form(¶ms)
132 .basic_auth(client_id, Some(client_secret))
133 .timeout(timeout)
134 },
135 try_count,
136 retry_millis,
137 )
138 .await
139}
140
141pub(crate) async fn execute_retry<T>(
142 f: impl Fn() -> RequestBuilder,
143 try_count: usize,
144 retry_millis: u64,
145) -> Result<(T, StatusCode, HeaderMap), OAuth2Error>
146where
147 T: serde::de::DeserializeOwned,
148{
149 for i in 0..try_count {
150 let req = f();
151 let res = req.send().await?;
152 let status = res.status();
153 let headers = res.headers().clone();
154 if status.is_success() {
155 let json: T = res.json().await?;
156 return Ok((json, status, headers));
157 } else if status.is_client_error() {
158 let body = res.text().await.unwrap_or_default();
159 return Err(OAuth2Error::ClientError(body, status, headers));
160 }
161 if i + 1 < try_count {
162 let jitter: u64 = rand::random::<u64>() % retry_millis;
164 let exp_backoff = 2u64.pow(i as u32) * retry_millis;
165 let retry_duration = Duration::from_millis(exp_backoff + jitter);
166 tokio::time::sleep(retry_duration).await;
167 } else {
168 let body = res.text().await.unwrap_or_default();
169 return Err(OAuth2Error::RetryOver(body, status, headers));
170 }
171 }
172 unreachable!()
173}
174
#[cfg(test)]
mod tests {
    use crate::x::{X_AUTHORIZE_URL, XScope};

    use super::*;

    /// Builds an X authorization URL and checks its shape.
    ///
    /// Set `CLIENT_ID` / `REDIRECT_URL` to get a printed URL usable for a manual
    /// login; without them placeholder values are used so the test runs unattended.
    #[test]
    fn test_authorize() {
        let client_id =
            std::env::var("CLIENT_ID").unwrap_or_else(|_| "test_client_id".to_string());
        let redirect_url = std::env::var("REDIRECT_URL")
            .unwrap_or_else(|_| "http://127.0.0.1:3000/callback".to_string());
        let state = "test_state";
        let scopes = XScope::scopes_to_string(&XScope::all());
        let code_challenge = "test_code_challenge";
        let res = authorize_url(
            X_AUTHORIZE_URL,
            ResponseType::Code,
            &client_id,
            &redirect_url,
            &scopes,
            state,
            code_challenge,
            CodeChallengeMethod::Plain,
        );
        println!("res: {}", res);

        assert!(res.starts_with(X_AUTHORIZE_URL));
        assert!(res.contains("response_type=code"));
        assert!(res.contains("state=test_state"));
        assert!(res.contains("code_challenge=test_code_challenge"));
        assert!(res.contains("code_challenge_method=plain"));
    }
}