use std::fmt;
use crate::{
error::{Error, Result},
util::handle_oauth2_error_response,
};
use reqwest::Client;
use serde::Deserialize;
use url::Url;
/// File-access permission requested during authorization, rendered as an OAuth2 `scope`.
///
/// Start from read-only access with [`Permission::new_read`] and widen it with the
/// builder-style setters.
#[derive(Clone, Copy, Debug, Default)]
pub struct Permission {
    write: bool,
    access_shared: bool,
    offline_access: bool,
}

impl Permission {
    /// Read-only permission: no write access, no `.all` suffix, no `offline_access`.
    #[must_use]
    pub fn new_read() -> Self {
        Self {
            write: false,
            access_shared: false,
            offline_access: false,
        }
    }

    /// Request write access (`files.readwrite` instead of `files.read`).
    #[must_use]
    pub fn write(self, write: bool) -> Self {
        Self { write, ..self }
    }

    /// Request the `.all` variant of the file scope (access beyond the user's own items).
    #[must_use]
    pub fn access_shared(self, access_shared: bool) -> Self {
        Self {
            access_shared,
            ..self
        }
    }

    /// Additionally request the `offline_access` scope.
    ///
    /// Login via `Auth` then requires the token response to contain a refresh token.
    #[must_use]
    pub fn offline_access(self, offline_access: bool) -> Self {
        Self {
            offline_access,
            ..self
        }
    }

    /// Render the space-separated `scope` value, e.g. `files.readwrite.all offline_access`.
    #[must_use]
    fn to_scope_string(self) -> String {
        let base = if self.write {
            "files.readwrite"
        } else {
            "files.read"
        };
        let mut scope = String::from(base);
        if self.access_shared {
            scope.push_str(".all");
        }
        if self.offline_access {
            scope.push_str(" offline_access");
        }
        scope
    }
}
/// Tenant selector for the authorization endpoints.
///
/// Its string form is used as the first path segment of
/// `https://login.microsoftonline.com/{tenant}/oauth2/v2.0/...`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Tenant {
    /// The `common` tenant segment.
    Common,
    /// The `organizations` tenant segment.
    Organizations,
    /// The `consumers` tenant segment.
    Consumers,
    /// A caller-supplied tenant identifier, used verbatim.
    Issuer(String),
}

impl Tenant {
    /// The path segment identifying this tenant.
    fn to_issuer(&self) -> &str {
        let segment: &str = match *self {
            Self::Common => "common",
            Self::Organizations => "organizations",
            Self::Consumers => "consumers",
            Self::Issuer(ref id) => id.as_str(),
        };
        segment
    }
}
/// OAuth2 client configuration for the Microsoft identity platform v2.0 endpoints.
///
/// Bundles the HTTP client, application (client) id, requested [`Permission`],
/// redirect URI and [`Tenant`]. It builds the authorization URL and exchanges
/// authorization codes or refresh tokens for a [`TokenResponse`].
#[derive(Debug, Clone)]
pub struct Auth {
    client: Client,
    client_id: String,
    permission: Permission,
    redirect_uri: String,
    tenant: Tenant,
}

impl Auth {
    /// Create an `Auth` that uses a freshly constructed [`Client`].
    ///
    /// Use [`Auth::new_with_client`] to share an existing client.
    pub fn new(
        client_id: impl Into<String>,
        permission: Permission,
        redirect_uri: impl Into<String>,
        tenant: Tenant,
    ) -> Self {
        Self::new_with_client(Client::new(), client_id, permission, redirect_uri, tenant)
    }

    /// Create an `Auth` that sends its token requests through `client`.
    pub fn new_with_client(
        client: Client,
        client_id: impl Into<String>,
        permission: Permission,
        redirect_uri: impl Into<String>,
        tenant: Tenant,
    ) -> Self {
        Self {
            client,
            client_id: client_id.into(),
            permission,
            redirect_uri: redirect_uri.into(),
            tenant,
        }
    }

    /// The HTTP client used for token requests.
    #[must_use]
    pub fn client(&self) -> &Client {
        &self.client
    }

    /// The application (client) id sent as `client_id`.
    #[must_use]
    pub fn client_id(&self) -> &str {
        &self.client_id
    }

    /// The permission rendered into the `scope` of the authorization URL.
    #[must_use]
    pub fn permission(&self) -> &Permission {
        &self.permission
    }

    /// The redirect URI sent as `redirect_uri`.
    #[must_use]
    pub fn redirect_uri(&self) -> &str {
        &self.redirect_uri
    }

    /// The tenant used as the first path segment of every endpoint URL.
    #[must_use]
    pub fn tenant(&self) -> &Tenant {
        &self.tenant
    }

    /// Build `https://login.microsoftonline.com/{tenant}/oauth2/v2.0/{endpoint}`.
    ///
    /// Each piece is appended as its own path segment, so characters such as `/`
    /// in a custom tenant id are percent-encoded rather than treated as path syntax.
    #[must_use]
    fn endpoint_url(&self, endpoint: &str) -> Url {
        let mut url = Url::parse("https://login.microsoftonline.com")
            .expect("hard-coded authority URL is valid");
        url.path_segments_mut()
            .expect("an https URL always has a hierarchical path")
            .extend([self.tenant.to_issuer(), "oauth2", "v2.0", endpoint]);
        url
    }

    /// URL of the `authorize` endpoint for the authorization-code flow.
    ///
    /// Query: `client_id`, `scope` (from the permission), `redirect_uri`, and
    /// `response_type=code`.
    #[must_use]
    pub fn code_auth_url(&self) -> Url {
        let mut url = self.endpoint_url("authorize");
        url.query_pairs_mut()
            .append_pair("client_id", &self.client_id)
            .append_pair("scope", &self.permission.to_scope_string())
            .append_pair("redirect_uri", &self.redirect_uri)
            .append_pair("response_type", "code");
        url
    }

    /// POST `params` as a form to the `token` endpoint and parse the response.
    ///
    /// When `require_refresh` is set, a response without `refresh_token` is rejected.
    async fn request_token<'a>(
        &self,
        require_refresh: bool,
        params: impl Iterator<Item = (&'a str, &'a str)>,
    ) -> Result<TokenResponse> {
        let url = self.endpoint_url("token");
        // `RequestBuilder::form` takes a `&T: Serialize`; a bare iterator is not
        // serializable, so collect the pairs first.
        let params = params.collect::<Vec<_>>();
        let resp = self.client.post(url).form(&params).send().await?;
        let token_resp: TokenResponse = handle_oauth2_error_response(resp).await?.json().await?;
        if require_refresh && token_resp.refresh_token.is_none() {
            return Err(Error::unexpected_response("Missing field `refresh_token`"));
        }
        Ok(token_resp)
    }

    /// Exchange an authorization `code` (from the [`Auth::code_auth_url`] flow) for tokens.
    ///
    /// If the permission includes `offline_access`, the response must carry a refresh token.
    ///
    /// # Errors
    /// Fails if any of these happens:
    /// - the HTTP request fails;
    /// - `handle_oauth2_error_response` rejects the response;
    /// - the body is not a valid [`TokenResponse`];
    /// - a required `refresh_token` is missing.
    pub async fn login_with_code(
        &self,
        code: &str,
        client_credential: &ClientCredential,
    ) -> Result<TokenResponse> {
        self.request_token(
            self.permission.offline_access,
            [
                ("client_id", self.client_id.as_str()),
                ("code", code),
                ("grant_type", "authorization_code"),
                ("redirect_uri", self.redirect_uri.as_str()),
            ]
            .into_iter()
            .chain(client_credential.params()),
        )
        .await
    }

    /// Obtain fresh tokens using a previously issued `refresh_token`.
    ///
    /// The response must contain a (possibly rotated) refresh token.
    ///
    /// # Errors
    /// Same conditions as [`Auth::login_with_code`]. A missing `refresh_token` is
    /// always an error here.
    ///
    /// # Panics
    /// Panics if this `Auth` was not configured with the `offline_access` permission.
    pub async fn login_with_refresh_token(
        &self,
        refresh_token: &str,
        client_credential: &ClientCredential,
    ) -> Result<TokenResponse> {
        assert!(
            self.permission.offline_access,
            "Refresh token requires offline_access permission."
        );
        self.request_token(
            true,
            [
                ("client_id", self.client_id.as_str()),
                ("grant_type", "refresh_token"),
                ("redirect_uri", self.redirect_uri.as_str()),
                ("refresh_token", refresh_token),
            ]
            .into_iter()
            .chain(client_credential.params()),
        )
        .await
    }
}
/// Client authentication attached to token requests.
///
/// The `Debug` output names the variant only and never prints the wrapped credential.
#[derive(Default, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ClientCredential {
    /// No client authentication parameters are sent (the default).
    #[default]
    None,
    /// A client secret, sent as the `client_secret` form parameter.
    Secret(String),
    /// A client assertion, sent as `client_assertion` with the JWT-bearer assertion type.
    Assertion(String),
}

impl fmt::Debug for ClientCredential {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Credential-bearing variants render as `Name { .. }` to keep secrets out of logs.
        let variant = match self {
            Self::None => return f.write_str("None"),
            Self::Secret(_) => "Secret",
            Self::Assertion(_) => "Assertion",
        };
        f.debug_struct(variant).finish_non_exhaustive()
    }
}

impl ClientCredential {
    /// Form parameters that authenticate the client in a token request.
    ///
    /// Yields the following for each variant:
    /// - `None`: nothing;
    /// - `Secret`: `client_secret`;
    /// - `Assertion`: `client_assertion_type`, then `client_assertion`.
    fn params(&self) -> impl Iterator<Item = (&str, &str)> {
        const JWT_BEARER: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
        let slots: [Option<(&str, &str)>; 2] = match self {
            Self::None => [None, None],
            Self::Secret(secret) => [Some(("client_secret", secret.as_str())), None],
            Self::Assertion(assertion) => [
                Some(("client_assertion_type", JWT_BEARER)),
                Some(("client_assertion", assertion.as_str())),
            ],
        };
        IntoIterator::into_iter(slots).flatten()
    }
}
/// Successful response body from the `token` endpoint.
#[derive(Clone, Deserialize)]
#[non_exhaustive]
pub struct TokenResponse {
    /// The `token_type` field of the response.
    pub token_type: String,
    /// Granted scopes, split from the space-separated `scope` field.
    #[serde(deserialize_with = "space_separated_strings")]
    pub scope: Vec<String>,
    /// Access-token lifetime in seconds (the `expires_in` field).
    #[serde(rename = "expires_in")]
    pub expires_in_secs: u64,
    /// The access token.
    pub access_token: String,
    /// The refresh token, if the server returned one.
    pub refresh_token: Option<String>,
}

impl fmt::Debug for TokenResponse {
    /// Prints only non-secret fields; `access_token` and `refresh_token` are elided as `..`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("TokenResponse");
        out.field("token_type", &self.token_type);
        out.field("scope", &self.scope);
        out.field("expires_in_secs", &self.expires_in_secs);
        out.finish_non_exhaustive()
    }
}
/// Deserialize an OAuth2 `scope` string (space-delimited scope tokens) into its tokens.
///
/// Empty tokens are discarded. An empty string yields an empty `Vec`, and
/// repeated, leading or trailing spaces do not produce `""` entries.
fn space_separated_strings<'de, D>(deserializer: D) -> std::result::Result<Vec<String>, D::Error>
where
    D: serde::de::Deserializer<'de>,
{
    struct Visitor;

    impl serde::de::Visitor<'_> for Visitor {
        type Value = Vec<String>;

        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
            formatter.write_str("space-separated strings")
        }

        fn visit_str<E>(self, s: &str) -> std::result::Result<Self::Value, E>
        where
            E: serde::de::Error,
        {
            // Split on U+0020 only (the scope delimiter). `split` alone yields `""` for
            // an empty input and between consecutive spaces, so drop those pieces.
            Ok(s.split(' ')
                .filter(|token| !token.is_empty())
                .map(Into::into)
                .collect())
        }
    }

    deserializer.deserialize_str(Visitor)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The authorization URL puts the tenant in the path. Client id, scope, redirect
    /// URI and response type go into the form-encoded query.
    #[test]
    fn auth_url() {
        let auth = Auth::new(
            "some-client-id",
            Permission::new_read().write(true).offline_access(true),
            "http://example.com",
            Tenant::Consumers,
        );
        let expected = "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize\
            ?client_id=some-client-id\
            &scope=files.readwrite+offline_access\
            &redirect_uri=http%3A%2F%2Fexample.com\
            &response_type=code";
        assert_eq!(auth.code_auth_url().as_str(), expected);
    }
}