use oauth2::{
AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EndpointNotSet, EndpointSet,
HttpClientError, RedirectUrl, RefreshToken, Scope, TokenResponse, TokenUrl,
basic::{BasicClient, BasicRequestTokenError, BasicTokenResponse},
};
use reqwest::Client;
use serde::Deserialize;
use std::path::PathBuf;
use url::Url;
use super::channel_messenger::{AuthContext, ChannelMessenger};
use crate::error::Error;
use crate::token::Token;
type RequestTokenError = BasicRequestTokenError<HttpClientError<reqwest::Error>>;
#[derive(Debug, Deserialize)]
pub(super) struct AuthRequest {
pub(super) code: String,
pub(super) state: String,
}
#[derive(Debug)]
pub(super) struct Authorizer<CM: ChannelMessenger> {
oauth2_client:
BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>,
async_client: Client,
messenger: CM,
}
impl<CM: ChannelMessenger> Authorizer<CM> {
pub(super) async fn new(
app_key: String,
secret: String,
redirect_url: String,
async_client: Client,
messenger: CM,
) -> Result<Self, Error> {
let app_key = ClientId::new(app_key);
let secret = ClientSecret::new(secret);
let auth_url = AuthUrl::new("https://api.schwabapi.com/v1/oauth/authorize".to_string())
.expect("Invalid authorization endpoint URL");
let token_url = TokenUrl::new("https://api.schwabapi.com/v1/oauth/token".to_string())
.expect("Invalid token endpoint URL");
let redirect_url = RedirectUrl::new(redirect_url).expect("Invalid redirect URL");
let oauth2_client = BasicClient::new(app_key)
.set_client_secret(secret)
.set_auth_uri(auth_url)
.set_token_uri(token_url)
.set_redirect_uri(redirect_url);
let mut auth = Authorizer {
oauth2_client,
async_client,
messenger,
};
let context = auth.create_auth_context();
auth.messenger.with_context(context).await?;
Ok(auth)
}
async fn authorize(&self) -> Result<Token, Error> {
let auth_code = {
self.messenger.send_auth_message().await?;
AuthorizationCode::new(
self.messenger
.receive_auth_message()
.await
.expect("Failed to get auth message."),
)
};
let token_result = self
.refresh_token(auth_code)
.await
.map_err(|e| Error::Token(e.to_string()))?;
let token = Token {
refresh: token_result
.refresh_token()
.expect("should have refresh_token")
.secret()
.to_string(),
refresh_expires_in: chrono::Utc::now()
.checked_add_signed(super::REFRESH_TOKEN_LIFETIME)
.expect("refresh_expires_in"),
access: token_result.access_token().secret().to_string(),
access_expires_in: chrono::Utc::now()
.checked_add_signed(super::ACCESS_TOKEN_LIFETIME)
.expect("access_expires_in"),
type_: token_result.token_type().as_ref().to_string(),
};
Ok(token)
}
fn auth_code_url(&self) -> (Url, CsrfToken) {
let (auth_url, csrf_token) = self
.oauth2_client
.authorize_url(CsrfToken::new_random)
.add_scope(Scope::new("readonly".to_string()))
.url();
(auth_url, csrf_token)
}
async fn refresh_token(
&self,
auth_code: AuthorizationCode,
) -> Result<BasicTokenResponse, RequestTokenError> {
self.oauth2_client
.exchange_code(auth_code)
.request_async(&self.async_client)
.await
}
pub(super) async fn access_token(
&self,
refresh_token: &str,
) -> Result<BasicTokenResponse, RequestTokenError> {
let refresh_token = RefreshToken::new(refresh_token.to_string());
self.oauth2_client
.exchange_refresh_token(&refresh_token)
.request_async(&self.async_client)
.await
}
fn create_auth_context(&self) -> AuthContext {
let (auth_url, csrf_token) = self.auth_code_url();
AuthContext {
auth_url: Some(auth_url),
csrf: Some(csrf_token),
redirect_url: Some(
self.oauth2_client
.redirect_uri()
.expect("redirect_url")
.url()
.clone(),
),
}
}
pub(super) async fn save(&self, path: PathBuf) -> Result<Token, Error> {
let token = self
.authorize()
.await
.map_err(|e| Error::Token(e.to_string()))?;
token.save(path)?;
Ok(token)
}
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
use std::{borrow::Cow, collections::HashMap};
use crate::token::channel_messenger::compound_messenger::CompoundMessenger;
use crate::token::channel_messenger::local_server::LocalServerMessenger;
use crate::token::channel_messenger::stdio_messenger::StdioMessenger;
fn client_id_static() -> &'static str {
#[allow(clippy::option_env_unwrap)]
option_env!("SCHWAB_API_KEY")
.expect("The environment variable SCHWAB_API_KEY sholud be set")
}
fn secret_static() -> &'static str {
#[allow(clippy::option_env_unwrap)]
option_env!("SCHWAB_SECRET").expect("The environment variable SCHWAB_SECRET sholud be set")
}
fn callback_url_static() -> &'static str {
#[allow(clippy::option_env_unwrap)]
option_env!("SCHWAB_CALLBACK_URL")
.expect("The environment variable SCHWAB_CALLBACK_URL sholud be set")
}
#[tokio::test]
#[ignore = "Testing manually for compound verification. Should be --nocapture"]
async fn test_auth_compound() {
let certs_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/certs");
let messenger = CompoundMessenger::new(
LocalServerMessenger::new(&certs_dir).await,
StdioMessenger::new(),
);
let auth = Authorizer::new(
client_id_static().to_string(),
secret_static().to_string(),
callback_url_static().to_string(),
Client::new(),
messenger,
)
.await
.unwrap();
let token = auth.authorize().await.unwrap();
dbg!(&token);
let access_token = auth.access_token(&token.refresh).await.unwrap();
dbg!(&access_token);
}
#[tokio::test]
#[ignore = "Testing manually for browser verification. Should be --nocapture"]
async fn test_auth_local_server() {
let certs_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/certs");
let messenger = LocalServerMessenger::new(&certs_dir).await;
let auth = Authorizer::new(
client_id_static().to_string(),
secret_static().to_string(),
callback_url_static().to_string(),
Client::new(),
messenger,
)
.await
.unwrap();
let token = auth.authorize().await.unwrap();
dbg!(&token);
let access_token = auth.access_token(&token.refresh).await.unwrap();
dbg!(&access_token);
}
#[tokio::test]
#[ignore = "Testing manually for stdio verification. Should be --nocapture"]
async fn test_auth_stdio() {
let messenger = StdioMessenger::new();
let auth = Authorizer::new(
client_id_static().to_string(),
secret_static().to_string(),
callback_url_static().to_string(),
Client::new(),
messenger,
)
.await
.unwrap();
let token = auth.authorize().await.unwrap();
dbg!(&token);
let access_token = auth.access_token(&token.refresh).await.unwrap();
dbg!(&access_token);
}
#[tokio::test]
async fn test_get_auth_code_url() {
const CLIENTID: &str = "CLIENTID";
const SECRET: &str = "SECRET";
const REDIRECT_URL: &str = "https://127.0.0.1:8080";
let messenger = StdioMessenger::new();
let auth = Authorizer::new(
CLIENTID.to_string(),
SECRET.to_string(),
REDIRECT_URL.to_string(),
Client::new(),
messenger,
)
.await
.unwrap();
let (auth_url, csrf_token) = auth.auth_code_url();
println!("{auth_url:?}");
assert_eq!(auth_url.scheme(), "https");
assert_eq!(auth_url.host_str().unwrap(), "api.schwabapi.com");
assert_eq!(auth_url.path(), "/v1/oauth/authorize");
let pairs: HashMap<_, _> = auth_url.query_pairs().into_iter().collect();
assert_eq!(pairs.len(), 5);
assert_eq!(
pairs.get(&Cow::Borrowed("state")).unwrap(),
&Cow::Borrowed(csrf_token.secret().as_str())
);
assert_eq!(
pairs.get(&Cow::Borrowed("response_type")).unwrap(),
&Cow::Borrowed("code")
);
assert_eq!(
pairs.get(&Cow::Borrowed("client_id")).unwrap(),
&Cow::Borrowed(CLIENTID)
);
assert_eq!(
pairs.get(&Cow::Borrowed("redirect_uri")).unwrap(),
&Cow::Borrowed(REDIRECT_URL)
);
assert_eq!(
pairs.get(&Cow::Borrowed("scope")).unwrap(),
&Cow::Borrowed("readonly")
);
assert!(!csrf_token.secret().is_empty());
}
}