1use std::{fmt, io, net::TcpListener, vec};
7
8use oauth::v2_0::{AuthorizationCodeGrant, Client, RefreshAccessToken};
9use secret::Secret;
10use tracing::debug;
11
12#[doc(inline)]
13pub use super::{Error, Result};
14
/// Configuration of an OAuth 2.0 authentication.
///
/// Gathers the client credentials, the authorization and token endpoints,
/// the access and refresh tokens (held as [`Secret`]s) and the optional
/// redirect parameters used to run the authorization code flow and to
/// refresh the access token.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[cfg_attr(
    feature = "derive",
    derive(serde::Serialize, serde::Deserialize),
    serde(rename_all = "kebab-case")
)]
pub struct OAuth2Config {
    /// The authentication mechanism (`XOAUTH2` or `OAUTHBEARER`).
    pub method: OAuth2Method,

    /// The client identifier given to the OAuth 2.0 client.
    pub client_id: String,

    /// The optional client secret.
    ///
    /// When it is set but cannot be found, the user is asked for it during
    /// configuration and it is stored via `set_if_keyring`.
    pub client_secret: Option<Secret>,

    /// The authorization endpoint URL.
    pub auth_url: String,

    /// The token endpoint URL.
    pub token_url: String,

    /// The access token.
    ///
    /// Missing values deserialize to the default secret, and empty secrets
    /// are not serialized.
    #[cfg_attr(
        feature = "derive",
        serde(default, skip_serializing_if = "Secret::is_empty")
    )]
    pub access_token: Secret,

    /// The refresh token.
    ///
    /// Missing values deserialize to the default secret, and empty secrets
    /// are not serialized.
    #[cfg_attr(
        feature = "derive",
        serde(default, skip_serializing_if = "Secret::is_empty")
    )]
    pub refresh_token: Secret,

    /// Whether PKCE is enabled on the authorization code grant.
    pub pkce: bool,

    /// The redirect URL scheme; `http` is used when unset.
    pub redirect_scheme: Option<String>,
    /// The redirect URL host; `localhost` is used when unset.
    pub redirect_host: Option<String>,
    /// The redirect URL port; the first available local port is used when
    /// unset.
    pub redirect_port: Option<u16>,

    /// The requested scopes, flattened into this configuration (keyed as
    /// either `scope` or `scopes`).
    #[cfg_attr(feature = "derive", serde(flatten))]
    pub scopes: OAuth2Scopes,
}
70
71impl OAuth2Config {
72 pub const LOCALHOST: &'static str = "localhost";
73
74 pub fn get_first_available_port() -> Result<u16> {
76 (49_152..65_535)
77 .find(|port| TcpListener::bind((OAuth2Config::LOCALHOST, *port)).is_ok())
78 .ok_or(Error::GetAvailablePortError)
79 }
80
81 pub async fn reset(&self) -> Result<()> {
83 if let Some(secret) = self.client_secret.as_ref() {
84 secret
85 .delete_if_keyring()
86 .await
87 .map_err(Error::DeleteClientSecretOauthError)?;
88 }
89
90 self.access_token
91 .delete_if_keyring()
92 .await
93 .map_err(Error::DeleteAccessTokenOauthError)?;
94 self.refresh_token
95 .delete_if_keyring()
96 .await
97 .map_err(Error::DeleteRefreshTokenOauthError)?;
98
99 Ok(())
100 }
101
102 pub async fn configure(
106 &self,
107 get_client_secret: impl Fn() -> io::Result<String>,
108 ) -> Result<()> {
109 if self.access_token.get().await.is_ok() {
110 return Ok(());
111 }
112
113 let redirect_scheme = match self.redirect_scheme.as_ref() {
114 Some(scheme) => scheme.clone(),
115 None => "http".into(),
116 };
117
118 let redirect_host = match self.redirect_host.as_ref() {
119 Some(host) => host.clone(),
120 None => OAuth2Config::LOCALHOST.to_owned(),
121 };
122
123 let redirect_port = match self.redirect_port {
124 Some(port) => port,
125 None => OAuth2Config::get_first_available_port()?,
126 };
127
128 let client_secret = match self.client_secret.as_ref() {
129 None => None,
130 Some(secret) => Some(match secret.find().await {
131 Ok(None) => {
132 debug!("cannot find oauth2 client secret from keyring, setting it");
133 secret
134 .set_if_keyring(
135 get_client_secret()
136 .map_err(Error::GetClientSecretFromUserOauthError)?,
137 )
138 .await
139 .map_err(Error::SetClientSecretIntoKeyringOauthError)
140 }
141 Ok(Some(client_secret)) => Ok(client_secret),
142 Err(err) => Err(Error::GetClientSecretFromKeyringOauthError(err)),
143 }?),
144 };
145
146 let client = Client::new(
147 self.client_id.clone(),
148 client_secret,
149 self.auth_url.clone(),
150 self.token_url.clone(),
151 redirect_scheme,
152 redirect_host,
153 redirect_port,
154 )
155 .map_err(Error::BuildOauthClientError)?;
156
157 let mut auth_code_grant = AuthorizationCodeGrant::new();
158
159 if self.pkce {
160 auth_code_grant = auth_code_grant.with_pkce();
161 }
162
163 for scope in self.scopes.clone() {
164 auth_code_grant = auth_code_grant.with_scope(scope);
165 }
166
167 let (redirect_url, csrf_token) = auth_code_grant.get_redirect_url(&client);
168
169 println!("To complete your OAuth 2.0 setup, click on the following link:");
170 println!();
171 println!("{}", redirect_url);
172
173 let (access_token, refresh_token) = auth_code_grant
174 .wait_for_redirection(&client, csrf_token)
175 .await
176 .map_err(Error::WaitForOauthRedirectionError)?;
177
178 self.access_token
179 .set_if_keyring(access_token)
180 .await
181 .map_err(Error::SetAccessTokenOauthError)?;
182
183 if let Some(refresh_token) = &refresh_token {
184 self.refresh_token
185 .set_if_keyring(refresh_token)
186 .await
187 .map_err(Error::SetRefreshTokenOauthError)?;
188 }
189
190 Ok(())
191 }
192
193 pub async fn refresh_access_token(&self) -> Result<String> {
196 let redirect_scheme = match self.redirect_scheme.as_ref() {
197 Some(scheme) => scheme.clone(),
198 None => "http".into(),
199 };
200
201 let redirect_host = match self.redirect_host.as_ref() {
202 Some(host) => host.clone(),
203 None => OAuth2Config::LOCALHOST.to_owned(),
204 };
205
206 let redirect_port = match self.redirect_port {
207 Some(port) => port,
208 None => OAuth2Config::get_first_available_port()?,
209 };
210
211 let client_secret = match self.client_secret.as_ref() {
212 None => None,
213 Some(secret) => {
214 let secret = secret
215 .get()
216 .await
217 .map_err(Error::GetClientSecretFromKeyringOauthError)?;
218 Some(secret)
219 }
220 };
221
222 let client = Client::new(
223 self.client_id.clone(),
224 client_secret,
225 self.auth_url.clone(),
226 self.token_url.clone(),
227 redirect_scheme,
228 redirect_host,
229 redirect_port,
230 )
231 .map_err(Error::BuildOauthClientError)?;
232
233 let refresh_token = self
234 .refresh_token
235 .get()
236 .await
237 .map_err(Error::GetRefreshTokenOauthError)?;
238
239 let (access_token, refresh_token) = RefreshAccessToken::new()
240 .refresh_access_token(&client, refresh_token)
241 .await
242 .map_err(Error::RefreshAccessTokenOauthError)?;
243
244 self.access_token
245 .set_if_keyring(&access_token)
246 .await
247 .map_err(Error::SetAccessTokenOauthError)?;
248
249 if let Some(refresh_token) = &refresh_token {
250 self.refresh_token
251 .set_if_keyring(refresh_token)
252 .await
253 .map_err(Error::SetRefreshTokenOauthError)?;
254 }
255
256 Ok(access_token)
257 }
258
259 pub async fn access_token(&self) -> Result<String> {
262 self.access_token
263 .get()
264 .await
265 .map_err(Error::GetAccessTokenOauthError)
266 }
267}
268
/// The OAuth 2.0 authentication mechanism.
///
/// It is displayed as `XOAUTH2` or `OAUTHBEARER`.
/// NOTE(review): presumably these are the SASL mechanism names sent to the
/// server — confirm against callers.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[cfg_attr(
    feature = "derive",
    derive(serde::Serialize, serde::Deserialize),
    serde(rename_all = "lowercase")
)]
pub enum OAuth2Method {
    /// The default mechanism.
    ///
    /// Deserialized from `xoauth2`, or from the upper-case alias `XOAUTH2`.
    #[default]
    #[cfg_attr(feature = "derive", serde(alias = "XOAUTH2"))]
    XOAuth2,
    /// Deserialized from `oauthbearer`, or from the upper-case alias
    /// `OAUTHBEARER`.
    #[cfg_attr(feature = "derive", serde(alias = "OAUTHBEARER"))]
    OAuthBearer,
}
284
285impl fmt::Display for OAuth2Method {
286 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287 match self {
288 Self::XOAuth2 => write!(f, "XOAUTH2"),
289 Self::OAuthBearer => write!(f, "OAUTHBEARER"),
290 }
291 }
292}
293
/// The scopes requested during the authorization code grant.
///
/// Holds either a single scope or a list of scopes, and defaults to an
/// empty list. Iterating over it by value yields every scope in order.
/// With the `derive` feature, the kebab-case variant name is used as the
/// key (`scope` or `scopes`).
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(
    feature = "derive",
    derive(serde::Serialize, serde::Deserialize),
    serde(rename_all = "kebab-case")
)]
pub enum OAuth2Scopes {
    /// A single scope.
    Scope(String),
    /// A list of scopes.
    Scopes(Vec<String>),
}
305
306impl Default for OAuth2Scopes {
307 fn default() -> Self {
308 Self::Scopes(Vec::new())
309 }
310}
311
312impl IntoIterator for OAuth2Scopes {
313 type IntoIter = vec::IntoIter<Self::Item>;
314 type Item = String;
315
316 fn into_iter(self) -> Self::IntoIter {
317 match self {
318 Self::Scope(scope) => vec![scope].into_iter(),
319 Self::Scopes(scopes) => scopes.into_iter(),
320 }
321 }
322}