use crate::async_client_trait::NoauthClient;
use crate::client_helpers::{parse_response, prepare_request};
use crate::client_trait_common::{Endpoint, ParamsType, Style};
use crate::Error;
use async_lock::RwLock;
use base64::engine::general_purpose::{URL_SAFE, URL_SAFE_NO_PAD};
use base64::Engine;
use ring::rand::{SecureRandom, SystemRandom};
use std::env;
use std::io::{self, IsTerminal, Write};
use std::sync::Arc;
use url::form_urlencoded::Serializer as UrlEncoder;
use url::Url;
/// The OAuth 2 authorization flow an app uses.
#[derive(Debug, Clone)]
pub enum Oauth2Type {
    /// Authorization-code flow; the token request is authenticated with the app's client secret.
    AuthorizationCode { client_secret: String },
    /// Authorization-code flow authenticated with a PKCE code verifier instead of a secret.
    PKCE(PkceCode),
    /// Implicit-grant flow: the value returned to the user is the access token itself.
    ImplicitGrant,
}
impl Oauth2Type {
    /// Value of the `response_type` query parameter sent to the authorize page for this flow.
    pub(crate) fn response_type_str(&self) -> &'static str {
        match self {
            Oauth2Type::ImplicitGrant => "token",
            Oauth2Type::AuthorizationCode { .. } => "code",
            Oauth2Type::PKCE(_) => "code",
        }
    }
}
/// The kind of access token to request on the authorize page.
#[derive(Debug, Copy, Clone)]
pub enum TokenType {
    /// A short-lived access token plus a refresh token (`token_access_type=offline`).
    ShortLivedAndRefresh,
    /// A short-lived access token only (`token_access_type=online`).
    ShortLived,
    /// Legacy token type; no `token_access_type` parameter is sent at all.
    #[deprecated]
    LongLived,
}
impl TokenType {
    /// Value for the `token_access_type` query parameter, or `None` when it should be omitted.
    pub(crate) fn token_access_type_str(self) -> Option<&'static str> {
        match self {
            // Matching the deprecated variant is intentional: it maps to "omit the parameter".
            #[allow(deprecated)]
            TokenType::LongLived => None,
            TokenType::ShortLived => Some("online"),
            TokenType::ShortLivedAndRefresh => Some("offline"),
        }
    }
}
/// A PKCE code verifier, freshly generated for each authorization attempt.
#[derive(Debug, Clone)]
pub struct PkceCode {
    /// The verifier: URL-safe base64 of 93 random bytes (124 characters).
    pub code: String,
}
impl PkceCode {
    /// Generate a new random verifier using the system's secure RNG.
    ///
    /// # Panics
    /// Panics if the system RNG fails to produce random bytes.
    // `new` draws from the system RNG, so a `Default` impl would hide that side effect.
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        let rng = SystemRandom::new();
        let mut entropy = [0u8; 93];
        rng.fill(&mut entropy)
            .expect("failed to get random bytes for PKCE");
        // 93 is a multiple of 3, so the padded URL_SAFE engine never emits '=' here.
        Self {
            code: URL_SAFE.encode(entropy),
        }
    }

    /// The `S256` code challenge: unpadded URL-safe base64 of the SHA-256 digest of `code`.
    pub fn s256(&self) -> String {
        let hash = ring::digest::digest(&ring::digest::SHA256, self.code.as_bytes());
        URL_SAFE_NO_PAD.encode(hash.as_ref())
    }
}
/// Builds the Dropbox authorize-page URL that the user opens to grant the app access.
#[derive(Debug)]
pub struct AuthorizeUrlBuilder<'a> {
    client_id: &'a str,
    flow_type: &'a Oauth2Type,
    token_type: TokenType,
    force_reapprove: bool,
    force_reauthentication: bool,
    disable_signup: bool,
    redirect_uri: Option<&'a str>,
    state: Option<&'a str>,
    require_role: Option<&'a str>,
    locale: Option<&'a str>,
    scope: Option<&'a str>,
}
impl<'a> AuthorizeUrlBuilder<'a> {
    /// Start a builder for `client_id` and `flow_type`.
    ///
    /// Defaults: [`TokenType::ShortLivedAndRefresh`], every flag off, every optional value unset.
    pub fn new(client_id: &'a str, flow_type: &'a Oauth2Type) -> Self {
        Self {
            client_id,
            flow_type,
            token_type: TokenType::ShortLivedAndRefresh,
            force_reapprove: false,
            force_reauthentication: false,
            disable_signup: false,
            redirect_uri: None,
            state: None,
            require_role: None,
            locale: None,
            scope: None,
        }
    }

    /// When `true`, adds `force_reapprove=true` to the URL.
    pub fn force_reapprove(self, value: bool) -> Self {
        Self {
            force_reapprove: value,
            ..self
        }
    }

    /// When `true`, adds `force_reauthentication=true` to the URL.
    pub fn force_reauthentication(self, value: bool) -> Self {
        Self {
            force_reauthentication: value,
            ..self
        }
    }

    /// When `true`, adds `disable_signup=true` to the URL.
    pub fn disable_signup(self, value: bool) -> Self {
        Self {
            disable_signup: value,
            ..self
        }
    }

    /// Sets the `redirect_uri` query parameter.
    pub fn redirect_uri(self, value: &'a str) -> Self {
        Self {
            redirect_uri: Some(value),
            ..self
        }
    }

    /// Sets the `state` query parameter.
    pub fn state(self, value: &'a str) -> Self {
        Self {
            state: Some(value),
            ..self
        }
    }

    /// Sets the `require_role` query parameter.
    pub fn require_role(self, value: &'a str) -> Self {
        Self {
            require_role: Some(value),
            ..self
        }
    }

    /// Sets the `locale` query parameter.
    pub fn locale(self, value: &'a str) -> Self {
        Self {
            locale: Some(value),
            ..self
        }
    }

    /// Chooses the token type, which controls the `token_access_type` query parameter.
    pub fn token_type(self, value: TokenType) -> Self {
        Self {
            token_type: value,
            ..self
        }
    }

    /// Sets the `scope` query parameter.
    pub fn scope(self, value: &'a str) -> Self {
        Self {
            scope: Some(value),
            ..self
        }
    }

    /// Produce the final authorize URL.
    ///
    /// For the PKCE flow, the `S256` code challenge derived from the verifier is appended.
    pub fn build(self) -> Url {
        let mut url = Url::parse("https://www.dropbox.com/oauth2/authorize").unwrap();

        // Parameters are listed in the exact order they are appended to the query string.
        let flags = [
            ("force_reapprove", self.force_reapprove),
            ("force_reauthentication", self.force_reauthentication),
            ("disable_signup", self.disable_signup),
        ];
        let options = [
            ("redirect_uri", self.redirect_uri),
            ("state", self.state),
            ("require_role", self.require_role),
            ("locale", self.locale),
            ("scope", self.scope),
        ];

        let mut query = url.query_pairs_mut();
        query.append_pair("response_type", self.flow_type.response_type_str());
        query.append_pair("client_id", self.client_id);
        if let Some(access_type) = self.token_type.token_access_type_str() {
            query.append_pair("token_access_type", access_type);
        }
        for &(name, enabled) in &flags {
            if enabled {
                query.append_pair(name, "true");
            }
        }
        for &(name, value) in &options {
            if let Some(value) = value {
                query.append_pair(name, value);
            }
        }
        if let Oauth2Type::PKCE(pkce) = self.flow_type {
            query.append_pair("code_challenge", &pkce.s256());
            query.append_pair("code_challenge_method", "S256");
        }
        // Release the mutable borrow of `url` held by the query serializer.
        drop(query);
        url
    }
}
/// Where an [`Authorization`] is in its lifecycle, and the credentials it holds at that point.
#[derive(Debug, Clone)]
enum AuthorizationState {
    /// An authorization code (or, for the implicit grant, the token itself) not yet exchanged.
    InitialAuth {
        flow_type: Oauth2Type,
        auth_code: String,
        redirect_uri: Option<String>,
    },
    /// A refresh token that can be exchanged for new access tokens.
    Refresh {
        refresh_token: String,
        client_secret: Option<String>,
    },
    /// An access token held without a refresh token.
    AccessToken {
        client_secret: Option<String>,
        token: String,
    },
}

/// OAuth 2 credentials for one app, plus the state needed to obtain access tokens.
#[derive(Debug, Clone)]
pub struct Authorization {
    /// The app key (client ID); empty when created from a long-lived access token.
    pub client_id: String,
    state: AuthorizationState,
}
impl Authorization {
pub fn client_id(&self) -> &str {
&self.client_id
}
pub fn from_auth_code(
client_id: String,
flow_type: Oauth2Type,
auth_code: String,
redirect_uri: Option<String>,
) -> Self {
Self {
client_id,
state: AuthorizationState::InitialAuth {
flow_type,
auth_code,
redirect_uri,
},
}
}
pub fn save(&self) -> Option<String> {
match &self.state {
AuthorizationState::AccessToken {
token,
client_secret,
} if client_secret.is_none() => {
Some(format!("1&{}", token))
}
AuthorizationState::Refresh { refresh_token, .. } => {
Some(format!("2&{}", refresh_token))
}
_ => None,
}
}
pub fn load(client_id: String, saved: &str) -> Option<Self> {
Some(match saved.get(0..2) {
Some("1&") =>
{
#[allow(deprecated)]
Self::from_long_lived_access_token(saved[2..].to_owned())
}
Some("2&") => Self::from_refresh_token(client_id, saved[2..].to_owned()),
_ => {
error!("unrecognized saved Authorization representation: {saved:?}");
return None;
}
})
}
pub fn from_refresh_token(client_id: String, refresh_token: String) -> Self {
Self {
client_id,
state: AuthorizationState::Refresh {
refresh_token,
client_secret: None,
},
}
}
pub fn from_client_secret_refresh_token(
client_id: String,
client_secret: String,
refresh_token: String,
) -> Self {
Self {
client_id,
state: AuthorizationState::Refresh {
refresh_token,
client_secret: Some(client_secret),
},
}
}
#[deprecated]
pub fn from_long_lived_access_token(access_token: String) -> Self {
Self {
client_id: String::new(),
state: AuthorizationState::AccessToken {
token: access_token,
client_secret: None,
},
}
}
if_feature! { "sync_routes",
pub fn obtain_access_token(
&mut self,
sync_client: impl crate::client_trait::NoauthClient
) -> Result<String, Error> {
use futures::FutureExt;
self.obtain_access_token_async(sync_client)
.now_or_never()
.expect("sync client future should resolve immediately")
}
}
pub async fn obtain_access_token_async(
&mut self,
client: impl NoauthClient,
) -> Result<String, Error> {
let mut redirect_uri = None;
let mut client_secret = None;
let mut pkce_code = None;
let mut refresh_token = None;
let mut auth_code = None;
match self.state.clone() {
AuthorizationState::AccessToken {
token,
client_secret: secret,
} => {
match secret {
None => {
return Ok(token);
}
Some(secret) => {
client_secret = Some(secret);
}
}
}
AuthorizationState::InitialAuth {
flow_type,
auth_code: code,
redirect_uri: uri,
} => {
match flow_type {
Oauth2Type::ImplicitGrant => {
self.state = AuthorizationState::AccessToken {
client_secret: None,
token: code.clone(),
};
return Ok(code);
}
Oauth2Type::AuthorizationCode {
client_secret: secret,
} => {
client_secret = Some(secret);
}
Oauth2Type::PKCE(pkce) => {
pkce_code = Some(pkce.code.clone());
}
}
auth_code = Some(code);
redirect_uri = uri;
}
AuthorizationState::Refresh {
refresh_token: refresh,
client_secret: secret,
} => {
refresh_token = Some(refresh);
if let Some(secret) = secret {
client_secret = Some(secret);
}
}
}
let params = {
let mut params = UrlEncoder::new(String::new());
if let Some(refresh) = &refresh_token {
params.append_pair("grant_type", "refresh_token");
params.append_pair("refresh_token", refresh);
} else {
params.append_pair("grant_type", "authorization_code");
params.append_pair("code", &auth_code.unwrap());
}
params.append_pair("client_id", &self.client_id);
if let Some(client_secret) = client_secret.as_deref() {
params.append_pair("client_secret", client_secret);
}
if let Some(pkce) = &pkce_code {
params.append_pair("code_verifier", pkce);
}
if refresh_token.is_none() {
if let Some(pkce) = pkce_code {
params.append_pair("code_verifier", &pkce);
} else {
params.append_pair(
"client_secret",
client_secret
.as_ref()
.expect("need either PKCE code or client secret"),
);
}
}
if let Some(value) = redirect_uri {
params.append_pair("redirect_uri", &value);
}
params.finish()
};
let (req, body) = prepare_request(
&client,
Endpoint::OAuth2,
Style::Rpc,
"oauth2/token",
params,
ParamsType::Form,
None,
None,
None,
None,
None,
);
let body = body.unwrap_or_default();
debug!("Requesting OAuth2 token");
let resp = client.execute(req, body).await?;
let (result_json, _, _) = parse_response(resp, Style::Rpc).await?;
let result_value = serde_json::from_str(&result_json)?;
debug!("OAuth2 response: {:?}", result_value);
let access_token: String;
let refresh_token: Option<String>;
match result_value {
serde_json::Value::Object(mut map) => {
match map.remove("access_token") {
Some(serde_json::Value::String(token)) => access_token = token,
_ => {
return Err(Error::UnexpectedResponse(
"no access token in response!".to_owned(),
))
}
}
match map.remove("refresh_token") {
Some(serde_json::Value::String(refresh)) => refresh_token = Some(refresh),
Some(_) => {
return Err(Error::UnexpectedResponse(
"refresh token is not a string!".to_owned(),
));
}
None => refresh_token = None,
}
}
_ => {
return Err(Error::UnexpectedResponse(
"response is not a JSON object".to_owned(),
))
}
}
match refresh_token {
Some(refresh) => {
self.state = AuthorizationState::Refresh {
refresh_token: refresh,
client_secret,
};
}
None if !matches!(self.state, AuthorizationState::Refresh { .. }) => {
self.state = AuthorizationState::AccessToken {
token: access_token.clone(),
client_secret,
};
}
_ => (),
}
Ok(access_token)
}
}
/// Shares an [`Authorization`] and its current access token between callers.
///
/// An empty token string means no access token has been obtained or set yet.
pub struct TokenCache {
    auth: RwLock<(Authorization, Arc<String>)>,
}
impl TokenCache {
    /// Wrap `auth` with no cached access token.
    pub fn new(auth: Authorization) -> Self {
        let no_token = Arc::new(String::new());
        Self {
            auth: RwLock::new((auth, no_token)),
        }
    }

    /// The cached access token, or `None` if none has been obtained or set yet.
    ///
    /// Uses `read_blocking`, so it blocks the calling thread while the write lock is held
    /// (e.g. during [`TokenCache::update_token`]).
    pub fn get_token(&self) -> Option<Arc<String>> {
        let guard = self.auth.read_blocking();
        if guard.1.is_empty() {
            return None;
        }
        Some(Arc::clone(&guard.1))
    }

    /// Replace `old_token` with a freshly obtained access token and return the current token.
    ///
    /// If the cached token no longer equals `old_token` once the write lock is acquired, it was
    /// already replaced, and it is returned without making another request.
    pub async fn update_token(
        &self,
        client: impl NoauthClient,
        old_token: Arc<String>,
    ) -> Result<Arc<String>, Error> {
        let mut guard = self.auth.write().await;
        let (auth, current) = &mut *guard;
        if *current == old_token {
            let fresh = auth.obtain_access_token_async(client).await?;
            *current = Arc::new(fresh);
        }
        Ok(Arc::clone(current))
    }

    /// Overwrite the cached access token with `access_token`.
    pub fn set_access_token(&self, access_token: String) {
        let token = Arc::new(access_token);
        self.auth.write_blocking().1 = token;
    }
}
/// Get an [`Authorization`] from the environment, or interactively via the terminal.
///
/// Sources are checked in this order:
/// 1. `DBX_OAUTH_TOKEN`: used as a long-lived access token.
/// 2. `DBX_CLIENT_ID` together with `DBX_OAUTH`: a string produced by [`Authorization::save`].
/// 3. An interactive prompt, used only when stdin is a terminal. It asks for an app key, prints a
///    PKCE authorize URL, and reads back the code. The returned authorization still has to
///    exchange that code for a token.
///
/// # Panics
/// Panics if neither environment source is usable and stdin is not a terminal. Also panics if
/// flushing stderr or reading stdin fails.
pub fn get_auth_from_env_or_prompt() -> Authorization {
    /// Print `msg` as a prompt on stderr and return the trimmed line read from stdin.
    fn prompt(msg: &str) -> String {
        eprint!("{msg}: ");
        io::stderr().flush().unwrap();
        let mut line = String::new();
        io::stdin().read_line(&mut line).unwrap();
        line.trim().to_owned()
    }

    if let Ok(token) = env::var("DBX_OAUTH_TOKEN") {
        #[allow(deprecated)]
        return Authorization::from_long_lived_access_token(token);
    }

    let client_id = env::var("DBX_CLIENT_ID");
    let saved = env::var("DBX_OAUTH");
    if let (Ok(client_id), Ok(saved)) = (client_id, saved) {
        if let Some(auth) = Authorization::load(client_id, &saved) {
            return auth;
        }
        eprintln!("saved authorization in DBX_CLIENT_ID and DBX_OAUTH are invalid");
    }

    assert!(
        io::stdin().is_terminal(),
        "DBX_CLIENT_ID and/or DBX_OAUTH not set, and stdin not a TTY; cannot authorize"
    );

    let client_id = prompt("Give me a Dropbox API app key");
    let flow = Oauth2Type::PKCE(PkceCode::new());
    let authorize_url = AuthorizeUrlBuilder::new(&client_id, &flow).build();
    eprintln!("Open this URL in your browser:");
    eprintln!("{authorize_url}");
    eprintln!();

    // `prompt` already trims its result.
    let code = prompt("Then paste the code here");
    Authorization::from_auth_code(client_id, flow, code, None)
}