#[macro_use] extern crate serde;
pub mod params;
pub mod responses;
use chrono::Utc;
use params::{
GetAccountParams,
GetAccountsParams,
GetMoversParams,
GetPriceHistoryParams,
};
use thiserror::Error;
use std::io;
/// Base URL of the TD Ameritrade REST API (v1).
///
/// Every request URL built by [`Client`] is this base with an endpoint path appended.
pub const TDA_API_BASE: &str = "https://api.tdameritrade.com/v1";
// NOTE(review): the derived `Debug` prints `refresh_token` and the access token
// value verbatim; avoid logging this struct, or hand-write a redacting `Debug`.
/// Synchronous client for the TD Ameritrade REST API rooted at [`TDA_API_BASE`].
///
/// Holds the OAuth credentials used to obtain access tokens, plus the access
/// token (if any) that the data endpoints send as a bearer credential.
#[derive(Debug)]
pub struct Client {
    /// Token sent as `Authorization: Bearer ...` by the data endpoints; `None` until set.
    pub access_token: Option<AccessToken>,
    /// OAuth client id, sent when requesting a new access token.
    client_id: String,
    /// OAuth refresh token, exchanged for a new access token.
    refresh_token: String,
}
impl<'a> Client {
pub fn new(client_id: &'a str, refresh_token: &'a str, access_token: Option<AccessToken>) -> Self {
Self {
access_token,
client_id: client_id.to_string(),
refresh_token: refresh_token.to_string(),
}
}
pub fn set_access_token(&mut self, access_token: &Option<AccessToken>) -> &mut Self {
self.access_token = access_token.clone();
self
}
pub fn get_access_token(&self) -> Result<responses::AccessTokenResponse, ClientError> {
let url = format!("{}/oauth2/token", TDA_API_BASE);
let response = ureq::post(&url)
.send_form(&[
("grant_type", "refresh_token"),
("refresh_token", &self.refresh_token),
("client_id", &self.client_id),
]);
let status = response.status();
let body = response.into_string().map_err(ClientError::ReadResponse)?;
if status != 200 {
return Err(ClientError::NotHttpOk(status, body))
}
serde_json::from_str(&body).map_err(ClientError::ParseResponse)
}
pub fn get_account(&self, account_id: &'a str, params: GetAccountParams) -> Result<responses::Account, ClientError> {
if self.access_token.is_none() {
panic!("Client does not have a token set!");
}
let access_token = self.access_token.as_ref().unwrap();
let url = format!("{}/accounts/{}", TDA_API_BASE, account_id);
let mut request = ureq::get(&url);
request.set("Authorization", &format!("Bearer {}", access_token.token));
if let Some(fields) = params.fields {
request.query("fields", &fields);
}
let response = request.call();
let status = response.status();
let body = response.into_string().map_err(ClientError::ReadResponse)?;
if status != 200 {
return Err(ClientError::NotHttpOk(status, body));
}
serde_json::from_str(&body).map_err(ClientError::ParseResponse)
}
pub fn get_accounts(&self, params: GetAccountsParams) -> Result<Vec<responses::Account>, ClientError> {
if self.access_token.is_none() {
panic!("Client does not have a token set!");
}
let access_token = self.access_token.as_ref().unwrap();
let url = format!("{}/accounts", TDA_API_BASE);
let mut request = ureq::get(&url);
request.set("Authorization", &format!("Bearer {}", access_token.token));
if let Some(fields) = params.fields {
request.query("fields", &fields);
}
let response = request.call();
let status = response.status();
let body = response.into_string().map_err(ClientError::ReadResponse)?;
if status != 200 {
return Err(ClientError::NotHttpOk(status, body));
}
serde_json::from_str(&body).map_err(ClientError::ParseResponse)
}
pub fn get_movers(&self, index: &'a str, params: GetMoversParams) -> Result<Vec<responses::Mover>, ClientError> {
if self.access_token.is_none() {
panic!("Client does not have a token set!");
}
let access_token = self.access_token.as_ref().unwrap();
let url = format!("{}/marketdata/{}/movers", TDA_API_BASE, index);
let mut request = ureq::get(&url);
request.set("Authorization", &format!("Bearer {}", access_token.token));
if let Some(direction) = params.direction {
request.query("direction", &direction);
}
if let Some(change) = params.change {
request.query("change", &change);
}
let response = request.call();
let status = response.status();
let body = response.into_string().map_err(ClientError::ReadResponse)?;
if status != 200 {
return Err(ClientError::NotHttpOk(status, body));
}
serde_json::from_str(&body).map_err(ClientError::ParseResponse)
}
pub fn get_price_history(&self, symbol: &str, params: GetPriceHistoryParams) -> Result<responses::GetPriceHistoryResponse, ClientError> {
if self.access_token.is_none() {
panic!("Client does not have a token set!");
}
let access_token = self.access_token.as_ref().unwrap();
let url = format!("{}/marketdata/{}/pricehistory", TDA_API_BASE, symbol);
let mut request = ureq::get(&url);
request.set("Authorization", &format!("Bearer {}", access_token.token));
if let Some(period_type) = params.period_type {
request.query("periodType", &period_type);
}
if let Some(period) = params.period {
request.query("period", &period);
}
if let Some(frequency_type) = params.frequency_type {
request.query("frequencyType", &frequency_type);
}
if let Some(frequency) = params.frequency {
request.query("frequency", &frequency);
}
if let Some(end_date) = params.end_date {
request.query("endDate", &end_date);
}
if let Some(start_date) = params.start_date {
request.query("startDate", &start_date);
}
if let Some(need_extended_hours_data) = params.need_extended_hours_data {
request.query("needExtendedHoursData", &need_extended_hours_data.to_string());
}
let response = request.call();
let status = response.status();
let body = response.into_string().map_err(ClientError::ReadResponse)?;
if status != 200 {
return Err(ClientError::NotHttpOk(status, body));
}
serde_json::from_str(&body).map_err(ClientError::ParseResponse)
}
}
#[derive(Clone, Debug, Serialize)]
pub struct AccessToken {
pub expires_at: i64,
pub scope: Vec<String>,
pub token: String,
}
impl From<responses::AccessTokenResponse> for AccessToken {
    /// Builds an [`AccessToken`] from a token-endpoint response, anchoring its
    /// expiry to the current clock.
    ///
    /// The space-separated `scope` string is split into individual scope names.
    /// An empty or blank `scope` yields an empty list.
    fn from(response: responses::AccessTokenResponse) -> Self {
        let now = Utc::now().naive_utc().timestamp_millis();
        Self {
            token: response.access_token,
            // `expires_in` is a lifetime in seconds (OAuth 2.0), while
            // `expires_at` is an epoch timestamp in milliseconds, so scale it.
            expires_at: now + response.expires_in * 1000,
            scope: response.scope.split_whitespace().map(String::from).collect(),
        }
    }
}
impl AccessToken {
    /// Returns `true` once the current time has reached `expires_at`, i.e. the
    /// token must be refreshed before further use.
    pub fn has_expired(&self) -> bool {
        // The original compared the other way round and reported valid tokens
        // as expired (and expired ones as valid).
        Utc::now().naive_utc().timestamp_millis() >= self.expires_at
    }
}
#[derive(Debug, Error)]
pub enum ClientError {
#[error("Received a {0} HTTP code: {1}")]
NotHttpOk(u16, String),
#[error("Failed to parse response: {0}")]
ParseResponse(#[from] serde_json::error::Error),
#[error("Failed to read response string: {0}")]
ReadResponse(#[from] io::Error),
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// dotenv file holding `TDA_CLIENT_ID` and `TDA_REFRESH_TOKEN`.
    const CONFIG_FILE: &str = "./.test.env";
    /// Cache of the last access token, written by `save_token`.
    const TOKEN_FILE_PATH: &str = "./.token.json";

    #[derive(Debug)]
    struct Config {
        tda_client_id: String,
        tda_refresh_token: String,
    }

    /// Builds a client with a usable access token.
    ///
    /// The token cached in `TOKEN_FILE_PATH` is reused while its `expires_at`
    /// is still in the future. Otherwise a new token is requested and cached.
    fn get_working_client() -> Client {
        let config = load_config();
        let mut client = Client::new(&config.tda_client_id, &config.tda_refresh_token, None);
        let now = Utc::now().naive_utc().timestamp_millis();
        let token = match load_token() {
            Some(cached) if cached.expires_at > now => cached,
            _ => {
                let fresh: AccessToken = client.get_access_token().unwrap().into();
                save_token(&fresh);
                fresh
            }
        };
        client.set_access_token(&Some(token));
        client
    }

    /// Loads credentials from `CONFIG_FILE` (if present) and the environment.
    fn load_config() -> Config {
        dotenv::from_path(CONFIG_FILE).ok();
        Config {
            tda_client_id: dotenv::var("TDA_CLIENT_ID").unwrap(),
            tda_refresh_token: dotenv::var("TDA_REFRESH_TOKEN").unwrap(),
        }
    }

    /// Reads the token written by `save_token`.
    ///
    /// Returns `None` when the file is missing or lacks the `expires_at` /
    /// `scope` / `token` fields that `AccessToken` serializes to. Fields are read
    /// individually so this does not rely on `AccessToken: Deserialize`.
    fn load_token() -> Option<AccessToken> {
        let raw = fs::read_to_string(TOKEN_FILE_PATH).ok()?;
        let value: serde_json::Value = serde_json::from_str(&raw).ok()?;
        let scope = value
            .get("scope")?
            .as_array()?
            .iter()
            .map(|entry| entry.as_str().map(String::from))
            .collect::<Option<Vec<_>>>()?;
        Some(AccessToken {
            expires_at: value.get("expires_at")?.as_i64()?,
            scope,
            token: value.get("token")?.as_str()?.to_string(),
        })
    }

    /// Serializes `token` as JSON into `TOKEN_FILE_PATH`.
    fn save_token(token: &AccessToken) {
        fs::write(TOKEN_FILE_PATH, serde_json::to_string(token).unwrap()).unwrap();
    }

    #[test]
    fn get_access_token() {
        let config = load_config();
        let client = Client::new(&config.tda_client_id, &config.tda_refresh_token, None);
        let token = client.get_access_token().unwrap();
        assert_ne!(token.access_token.len(), 0);
    }

    #[test]
    fn set_access_token() {
        let config = load_config();
        let mut client = Client::new(&config.tda_client_id, &config.tda_refresh_token, None);
        let response = client.get_access_token().unwrap();
        let new_access_token = response.access_token.clone();
        client.set_access_token(&Some(response.into()));
        assert_eq!(new_access_token, client.access_token.unwrap().token);
    }

    #[test]
    fn get_account() {
        let client = get_working_client();
        let accounts = client.get_accounts(GetAccountsParams::default()).unwrap();
        match &accounts.get(0).unwrap().securities_account {
            responses::SecuritiesAccount::MarginAccount { account_id, .. } => {
                client.get_account(account_id, GetAccountParams::default()).unwrap();
            }
        }
    }

    #[test]
    fn get_accounts() {
        let client = get_working_client();
        let accounts = client.get_accounts(GetAccountsParams::default()).unwrap();
        assert_ne!(accounts.len(), 0);
    }

    #[test]
    fn get_movers() {
        let client = get_working_client();
        let _movers = client.get_movers("$DJI", GetMoversParams::default()).unwrap();
    }

    #[test]
    fn get_price_history() {
        let client = get_working_client();
        let response = client.get_price_history("AAPL", GetPriceHistoryParams::default()).unwrap();
        assert_ne!(response.candles.len(), 0);
    }
}