use std::sync::Arc;
use reqwest::{Method, RequestBuilder};
use serde::Serialize;
use serde::de::DeserializeOwned;
use crate::accounts::Accounts;
use crate::constants::{MARKET_DATA_BASE_URL, TRADER_BASE_URL};
use crate::error::{Error, Result, map_response_to_error};
use crate::market_data::MarketData;
use crate::orders::{AllOrders, Orders};
use crate::secrets::{AccountHash, AuthToken};
use crate::streamer::{self, ReadHalf, WriteHalf};
use crate::token::{StaticTokenProvider, TokenProvider};
use crate::transactions::Transactions;
use crate::user_preferences::UserPreferences;
/// Client for the Schwab trader and market-data HTTP APIs.
///
/// Clones share the same token provider, which is held behind an `Arc`.
#[derive(Clone)]
pub struct SchwabClient {
    /// HTTP client that every request is built from.
    client: reqwest::Client,
    /// Prefix that trader API paths are appended to.
    trader_base_url: String,
    /// Prefix that market-data API paths are appended to.
    market_data_base_url: String,
    /// Source of bearer tokens; also handed to the streamer on connect.
    token_provider: Arc<dyn TokenProvider + Send + Sync>,
}
/// Hand-written `Debug` for [`SchwabClient`].
///
/// The token provider is rendered as a fixed placeholder string rather than
/// being formatted (NOTE(review): presumably to keep credentials out of logs,
/// and because `dyn TokenProvider` is not required to implement `Debug`).
/// The inner HTTP client is not printed; `finish_non_exhaustive` appends `..`
/// so readers can tell the output is partial.
impl std::fmt::Debug for SchwabClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SchwabClient")
            .field("trader_base_url", &self.trader_base_url)
            .field("market_data_base_url", &self.market_data_base_url)
            .field("token_provider", &"<dyn TokenProvider>")
            .finish_non_exhaustive()
    }
}
impl SchwabClient {
pub fn new(auth_token: AuthToken) -> Self {
Self::with_token_provider(Arc::new(StaticTokenProvider::new(auth_token)))
}
pub fn with_token_provider(provider: Arc<dyn TokenProvider + Send + Sync>) -> Self {
Self {
client: reqwest::Client::new(),
trader_base_url: TRADER_BASE_URL.to_string(),
market_data_base_url: MARKET_DATA_BASE_URL.to_string(),
token_provider: provider,
}
}
pub fn with_trader_base_url(mut self, url: impl Into<String>) -> Result<Self> {
let url = url.into();
validate_base_url(&url, cfg!(debug_assertions))?;
self.trader_base_url = url;
Ok(self)
}
pub fn with_market_data_base_url(mut self, url: impl Into<String>) -> Result<Self> {
let url = url.into();
validate_base_url(&url, cfg!(debug_assertions))?;
self.market_data_base_url = url;
Ok(self)
}
pub fn accounts(&self) -> Accounts<'_> {
Accounts::new(self)
}
pub fn orders<'a, 'b>(&'a self, account_hash: &'b AccountHash) -> Orders<'a, 'b> {
Orders::new(self, account_hash)
}
pub fn orders_all(&self) -> AllOrders<'_> {
AllOrders::new(self)
}
pub fn transactions<'a, 'b>(&'a self, account_hash: &'b AccountHash) -> Transactions<'a, 'b> {
Transactions::new(self, account_hash)
}
pub fn user_preferences(&self) -> UserPreferences<'_> {
UserPreferences::new(self)
}
pub fn market_data(&self) -> MarketData<'_> {
MarketData::new(self)
}
pub async fn streamer(&self) -> Result<(ReadHalf, WriteHalf)> {
let preferences = self
.user_preferences()
.get()
.await?
.into_iter()
.next()
.ok_or(Error::InvalidPreference {
field: "userPreference",
reason: "empty response".to_string(),
})?;
let streamer_info =
preferences
.streamer_info
.into_iter()
.next()
.ok_or(Error::InvalidPreference {
field: "streamerInfo",
reason: "missing".to_string(),
})?;
streamer::connect(streamer_info, self.token_provider.clone()).await
}
pub(crate) fn trader_http(&self) -> Transport<'_> {
Transport {
client: self,
base_url: &self.trader_base_url,
}
}
pub(crate) fn market_data_http(&self) -> Transport<'_> {
Transport {
client: self,
base_url: &self.market_data_base_url,
}
}
}
/// Request factory bound to one of a [`SchwabClient`]'s base URLs.
pub(crate) struct Transport<'a> {
    /// Owning client; supplies the HTTP client and the token provider.
    client: &'a SchwabClient,
    /// Prefix that request paths are concatenated onto.
    base_url: &'a str,
}
impl<'a> Transport<'a> {
fn request(&self, method: Method, path: &str) -> AuthedRequest<'a> {
AuthedRequest {
builder: self
.client
.client
.request(method, format!("{}{}", self.base_url, path)),
provider: &*self.client.token_provider,
}
}
pub(crate) fn get(&self, path: &str) -> AuthedRequest<'a> {
self.request(Method::GET, path)
}
pub(crate) fn post(&self, path: &str) -> AuthedRequest<'a> {
self.request(Method::POST, path)
}
pub(crate) fn put(&self, path: &str) -> AuthedRequest<'a> {
self.request(Method::PUT, path)
}
pub(crate) fn delete(&self, path: &str) -> AuthedRequest<'a> {
self.request(Method::DELETE, path)
}
pub(crate) async fn get_json<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
self.get(path).send_json().await
}
}
/// A pending request that gets a bearer token attached when it is sent.
pub(crate) struct AuthedRequest<'a> {
    /// Underlying reqwest builder (method, URL, query, body).
    builder: RequestBuilder,
    /// Provider queried for an access token in `send`.
    provider: &'a (dyn TokenProvider + Send + Sync),
}
impl<'a> AuthedRequest<'a> {
pub(crate) fn query<Q: Serialize + ?Sized>(mut self, q: &Q) -> Self {
self.builder = self.builder.query(q);
self
}
pub(crate) fn json<T: Serialize + ?Sized>(mut self, body: &T) -> Self {
self.builder = self.builder.json(body);
self
}
pub(crate) async fn send(self) -> Result<reqwest::Response> {
let token = self.provider.access_token().await?;
let response = self
.builder
.bearer_auth(token.expose_secret())
.send()
.await?;
if response.status().is_success() {
Ok(response)
} else {
Err(map_response_to_error(response).await)
}
}
pub(crate) async fn send_json<T: DeserializeOwned>(self) -> Result<T> {
let response = self.send().await?;
let bytes = response.bytes().await?;
serde_json::from_slice(&bytes).map_err(|e| Error::Codec {
context: "decode response body".to_string(),
reason: e.to_string(),
})
}
}
fn validate_base_url(url: &str, allow_insecure: bool) -> Result<()> {
if url.starts_with("https://") {
return Ok(());
}
if allow_insecure && url.starts_with("http://") {
return Ok(());
}
Err(Error::InsecureBaseUrl {
url: url.to_string(),
reason: if allow_insecure {
"expected http:// or https://".to_string()
} else {
"release builds require https://".to_string()
},
})
}
#[cfg(test)]
mod tests {
    use super::*;

    // URLs that must be refused whether or not insecure schemes are allowed.
    const NEVER_VALID: [&str; 8] = [
        "ftp://example.com",
        "ws://example.com",
        "wss://example.com",
        "javascript:alert(1)",
        "file:///etc/passwd",
        "",
        "api.schwabapi.com/trader/v1",
        "//api.schwabapi.com/trader/v1",
    ];

    #[test]
    fn https_is_accepted_in_both_modes() {
        let accepted = [
            ("https://api.schwabapi.com/trader/v1", false),
            ("https://api.schwabapi.com/trader/v1", true),
            ("https://127.0.0.1:8443/trader/v1", false),
        ];
        for (url, allow_insecure) in accepted {
            assert!(validate_base_url(url, allow_insecure).is_ok());
        }
    }

    #[test]
    fn http_is_rejected_when_insecure_disallowed() {
        match validate_base_url("http://127.0.0.1:8080", false).unwrap_err() {
            Error::InsecureBaseUrl { url, reason } => {
                assert_eq!(url, "http://127.0.0.1:8080");
                assert!(
                    reason.contains("https://"),
                    "reason should name the required scheme: {reason}"
                );
            }
            other => panic!("expected InsecureBaseUrl, got {other:?}"),
        }
    }

    #[test]
    fn http_is_accepted_when_insecure_permitted() {
        for url in ["http://127.0.0.1:8080", "http://localhost/trader/v1"] {
            assert!(validate_base_url(url, true).is_ok());
        }
    }

    #[test]
    fn other_schemes_are_always_rejected() {
        for url in NEVER_VALID {
            let with_insecure = validate_base_url(url, true).unwrap_err();
            assert!(
                matches!(with_insecure, Error::InsecureBaseUrl { .. }),
                "{url} should be rejected even with insecure mode on"
            );
            let without_insecure = validate_base_url(url, false).unwrap_err();
            assert!(
                matches!(without_insecure, Error::InsecureBaseUrl { .. }),
                "{url} should be rejected with insecure mode off"
            );
        }
    }

    #[test]
    fn case_sensitive_scheme_match() {
        let wrong_case = [
            ("HTTPS://api.schwabapi.com", true),
            ("Https://api.schwabapi.com", false),
        ];
        for (url, allow_insecure) in wrong_case {
            assert!(validate_base_url(url, allow_insecure).is_err());
        }
    }
}