use std::collections::HashMap;
use std::time::Duration;
use crate::error::AuthError;
use crate::{consts, error::AuthResult};
use oauth2::basic::BasicClient;
use oauth2::devicecode::{DeviceAuthorizationResponse, ExtraDeviceAuthorizationFields};
use oauth2::reqwest::async_http_client;
use oauth2::{
AuthType, AuthUrl, ClientId, DeviceAuthorizationUrl, RefreshToken, Scope, TokenResponse,
TokenUrl,
};
use serde::{Deserialize, Serialize};
use tokio::time::sleep;
/// Access/refresh token pair obtained from the Microsoft OAuth 2.0 token endpoint.
///
/// `Debug` is implemented by hand so that both tokens are redacted: formatting
/// this value with `{:?}` (e.g. in logs or error reports) never reveals the
/// bearer credentials it carries.
#[derive(Clone)]
pub struct MicrosoftToken {
    /// Access token issued by the token endpoint.
    pub access_token: String,
    /// Refresh token that can be exchanged for a new access token.
    pub refresh_token: String,
}

impl std::fmt::Debug for MicrosoftToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Field names are kept so the output stays useful for debugging; only
        // the secret values are hidden.
        f.debug_struct("MicrosoftToken")
            .field("access_token", &format_args!("[redacted]"))
            .field("refresh_token", &format_args!("[redacted]"))
            .finish()
    }
}
/// In-progress Microsoft device-code authentication, returned by
/// [`setup_authentication`].
///
/// Exposes the user code, verification URL and expiry from the device
/// authorization response; the flow is completed with `finish_authentication`.
#[cfg_attr(doc_cfg, doc(cfg(feature = "auth")))]
#[derive(Clone, Debug)]
pub struct DeviceCodeInfo {
    // OAuth client configured with the Microsoft auth, token and device
    // authorization endpoints; reused when polling for the token.
    client: BasicClient,
    // Device authorization response: user code, verification URI, expiry, and
    // any extra fields captured in `StoringFields`.
    details: StoringDeviceAuthorizationResponse,
}
impl DeviceCodeInfo {
#[instrument(name = "finish_ms_authentication", level = "trace", skip_all)]
pub(crate) async fn finish_authentication(self) -> AuthResult<MicrosoftToken> {
trace!("Poll for the access token");
let token = self
.client
.exchange_device_access_token(&self.details)
.request_async(async_http_client, sleep, None)
.await
.map_err(|err| AuthError::RequestTokenError(err.to_string()))?;
Ok(MicrosoftToken {
access_token: token.access_token().secret().to_owned(),
refresh_token: token
.refresh_token()
.map(|t| t.secret().to_owned())
.ok_or(AuthError::NoRefreshToken)?,
})
}
pub fn user_code(&self) -> String {
self.details.user_code().secret().to_owned()
}
pub fn verification_url(&self) -> String {
self.details.verification_uri().to_string()
}
pub fn expires_in(&self) -> Duration {
self.details.expires_in()
}
}
/// Starts the Microsoft OAuth 2.0 device-code flow for `client_id`.
///
/// Requests the `XboxLive.signin` and `XboxLive.offline_access` scopes and
/// returns a [`DeviceCodeInfo`] that holds the user code and verification URL,
/// plus the configured client needed to finish the flow.
///
/// # Errors
///
/// - Propagates (via `?`) failures to parse the configured auth, token or
///   device authorization URLs, or to build the device code request.
/// - [`AuthError::RequestTokenError`] if the device authorization request fails.
#[instrument(name = "setup_ms_authentication", level = "trace", skip_all)]
pub async fn setup_authentication(client_id: String) -> AuthResult<DeviceCodeInfo> {
    let id = ClientId::new(client_id);

    trace!("Parse auth URL");
    let authorize_endpoint = AuthUrl::new(consts::MS_AUTH_URL.to_string())?;
    trace!("Parse token URL");
    let token_endpoint = TokenUrl::new(consts::MS_TOKEN_URL.to_string())?;
    trace!("Parse device auth URL");
    let device_endpoint = DeviceAuthorizationUrl::new(consts::MS_DEVICE_AUTH_URL.to_string())?;

    trace!("Setup OAuth client");
    // Client credentials are sent in the request body rather than via HTTP
    // Basic auth; there is no client secret.
    let client = BasicClient::new(id, None, authorize_endpoint, Some(token_endpoint))
        .set_device_authorization_url(device_endpoint)
        .set_auth_type(AuthType::RequestBody);

    trace!("Retrieve user code and verification URL (Start authentication process)");
    let scopes = ["XboxLive.signin", "XboxLive.offline_access"]
        .map(|name| Scope::new(String::from(name)));
    let request = client.exchange_device_code()?.add_scopes(scopes);
    let details: StoringDeviceAuthorizationResponse = request
        .request_async(async_http_client)
        .await
        .map_err(|err| AuthError::RequestTokenError(err.to_string()))?;

    Ok(DeviceCodeInfo { client, details })
}
/// Exchanges a Microsoft refresh token for a new access token.
///
/// Per RFC 6749 §6, the authorization server *may* issue a new refresh token
/// together with the access token. If it does, the new one is returned. If it
/// does not, the refresh token passed in remains valid and is returned
/// unchanged, rather than failing an otherwise successful refresh.
///
/// # Errors
///
/// - Propagates (via `?`) failures to parse the configured auth or token URLs.
/// - [`AuthError::RequestTokenError`] if the token endpoint request fails.
#[instrument(name = "refresh_ms_authentication", level = "trace", skip_all)]
pub async fn refresh_token(client_id: String, refresh_token: String) -> AuthResult<MicrosoftToken> {
    let client_id = ClientId::new(client_id);
    let refresh_token = RefreshToken::new(refresh_token);
    trace!("Parse auth URL");
    let auth_url = AuthUrl::new(consts::MS_AUTH_URL.to_string())?;
    trace!("Parse token URL");
    let token_url = TokenUrl::new(consts::MS_TOKEN_URL.to_string())?;
    trace!("Setup OAuth client");
    let client = BasicClient::new(client_id, None, auth_url, Some(token_url))
        .set_auth_type(AuthType::RequestBody);
    trace!("Refresh token");
    let token = client
        .exchange_refresh_token(&refresh_token)
        .request_async(async_http_client)
        .await
        .map_err(|err| AuthError::RequestTokenError(err.to_string()))?;
    // RFC 6749 §6: refresh-token rotation is optional for the server. A missing
    // `refresh_token` in the response means the existing one stays valid.
    let issued_refresh_token = match token.refresh_token() {
        Some(rotated) => rotated.secret().to_owned(),
        None => {
            trace!("No new refresh token issued, keeping the existing one");
            refresh_token.secret().to_owned()
        }
    };
    Ok(MicrosoftToken {
        access_token: token.access_token().secret().to_owned(),
        refresh_token: issued_refresh_token,
    })
}
/// Extra fields of the device authorization response, kept as raw JSON values
/// keyed by field name.
///
/// NOTE(review): presumably oauth2 fills this from response fields it does not
/// model itself (e.g. via `#[serde(flatten)]`) — confirm against the oauth2
/// version in use.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct StoringFields(HashMap<String, serde_json::Value>);
// Empty impl: lets `StoringFields` serve as the extra-fields type parameter of
// `DeviceAuthorizationResponse` below.
impl ExtraDeviceAuthorizationFields for StoringFields {}
/// Device authorization response type used by this module, retaining any
/// extra fields in [`StoringFields`].
type StoringDeviceAuthorizationResponse = DeviceAuthorizationResponse<StoringFields>;