use tokio::runtime::{Handle, Runtime};
use crate::{
accounts::types::{AccountNumber, GetAccountBalancesResponse},
accounts::{api::blocking::Accounts, api::non_blocking::Accounts as NonBlockingAccounts},
client::non_blocking::TradierRestClient as AsyncClient,
user::{api::blocking::User, api::non_blocking::User as NonBlockingUser, UserProfileResponse},
utils::Sealed,
Config, Result,
};
/// Synchronous Tradier REST client.
///
/// Wraps the non-blocking `client::non_blocking::TradierRestClient` and drives each of its
/// request futures to completion on a privately owned Tokio runtime via `Runtime::block_on`.
///
/// NOTE(review): per Tokio's documentation, `Runtime::block_on` panics when called from
/// within an asynchronous execution context. Dropping a `Runtime` there is also reported to
/// panic. Avoid calling or dropping this client from inside async tasks — confirm against the
/// Tokio version in use.
#[derive(Debug)]
pub struct BlockingTradierRestClient {
// The async client that actually performs the HTTP requests.
// Declared before `runtime`, so it is dropped first (Rust drops struct fields in
// declaration order). Keep this order unless reordering has been confirmed safe.
rest_client: AsyncClient,
// Runtime used to block on `rest_client`'s futures. `new` builds it as a
// current-thread runtime with all drivers enabled.
runtime: Runtime,
}
impl BlockingTradierRestClient {
    /// Creates a blocking client backed by its own current-thread Tokio runtime.
    ///
    /// # Errors
    ///
    /// - Returns `crate::Error::BlockingClientInsideAsyncRuntime` when a Tokio runtime
    ///   context is already active on the calling thread.
    /// - Propagates any error raised while building the internal runtime.
    pub fn new(config: Config) -> Result<Self> {
        // Refuse to nest a runtime inside an already-running one.
        let inside_runtime = Handle::try_current().is_ok();
        if inside_runtime {
            return Err(crate::Error::BlockingClientInsideAsyncRuntime);
        }

        let rest_client = AsyncClient::new(config);
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()?;

        Ok(Self {
            rest_client,
            runtime,
        })
    }
}
// Marker impl of the crate-private `utils::Sealed` trait. Presumably this is what permits
// implementing the crate's sealed API traits (`User`, `Accounts`) for this type — TODO confirm
// against the definition of `Sealed`.
impl Sealed for BlockingTradierRestClient {}
impl User for BlockingTradierRestClient {
    /// Fetches the user profile, blocking the calling thread until the async request finishes.
    ///
    /// # Panics
    ///
    /// `Runtime::block_on` panics if this is invoked from within an asynchronous execution
    /// context.
    fn get_user_profile(&self) -> Result<UserProfileResponse> {
        let pending = self.rest_client.get_user_profile();
        self.runtime.block_on(pending)
    }
}
impl Accounts for BlockingTradierRestClient {
    /// Fetches balances for `account_number`, blocking until the async request finishes.
    ///
    /// # Panics
    ///
    /// `Runtime::block_on` panics if this is invoked from within an asynchronous execution
    /// context.
    fn get_account_balances(
        &self,
        account_number: &AccountNumber,
    ) -> Result<GetAccountBalancesResponse> {
        let request = self.rest_client.get_account_balances(account_number);
        self.runtime.block_on(request)
    }

    /// Fetches positions for `account_number`, blocking until the async request finishes.
    ///
    /// # Panics
    ///
    /// `Runtime::block_on` panics if this is invoked from within an asynchronous execution
    /// context.
    fn get_account_positions(
        &self,
        account_number: &AccountNumber,
    ) -> Result<crate::types::GetAccountPositionsResponse> {
        let request = self.rest_client.get_account_positions(account_number);
        self.runtime.block_on(request)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::{
        accounts::test_support::{GetAccountBalancesResponseWire, GetAccountPositionsResponseWire},
        user::test_support::GetUserProfileResponseWire,
        utils::tests::with_env_vars,
        Config,
    };
    use httpmock::MockServer;
    use proptest::prelude::*;

    /// Strategy for account identifiers: 1 to 255 printable ASCII characters (`0x20..0x7f`)
    /// that are not entirely whitespace.
    fn printable_ascii_account() -> impl Strategy<Value = String> {
        prop::collection::vec(0x20u8..0x7fu8, 1..256)
            // A plain map suffices; `prop_flat_map(|v| Just(..))` built a throwaway
            // strategy per value to produce the same string.
            .prop_map(|bytes| bytes.into_iter().map(char::from).collect::<String>())
            .prop_filter("Strings must not be empty or blank", |v| !v.trim().is_empty())
    }

    /// Exercises each blocking endpoint against a mock server. For each, it checks that
    /// exactly one matching request is made and that the response decodes successfully.
    #[test]
    fn test_blocking_client() {
        // `MockServer::mock` only needs `&self`, so the server is shared by plain reference
        // across proptest cases; no interior mutability is required.
        let server = MockServer::start();

        proptest!(|(response in any::<GetUserProfileResponseWire>())| {
            let mut operation = server.mock(|when, then| {
                when.path(url::Url::parse(&server.url("/v1/user/profile")).unwrap().path())
                    .header("accept", "application/json");
                then.status(200)
                    .header("content-type", "application/json")
                    .body(serde_json::to_vec(&response)
                        .expect("serialization of wire type for tests to work"));
            });
            with_env_vars(vec![("TRADIER_REST_BASE_URL", &server.base_url()),
                               ("TRADIER_ACCESS_TOKEN", "testToken")], || {
                let config = Config::new();
                let sut = BlockingTradierRestClient::new(config).expect("client to initialize");
                let result = sut.get_user_profile();
                operation.assert();
                assert_eq!(operation.calls(), 1);
                assert!(result.is_ok(), "request failed: {:?}", result.as_ref().err());
                operation.delete();
            });
        });

        proptest!(|(response in any::<GetAccountBalancesResponseWire>(),
                    ascii_string in printable_ascii_account())| {
            let mut operation = server.mock(|when, then| {
                // Parse through `Url` so the expected path is percent-encoded the same way
                // the request path will be.
                when.path(url::Url::parse(&server.url(format!("/v1/accounts/{ascii_string}/balances"))).unwrap().path())
                    .header("accept", "application/json");
                then.status(200)
                    .header("content-type", "application/json")
                    .body(serde_json::to_vec(&response)
                        .expect("serialization of wire type for tests to work"));
            });
            with_env_vars(vec![("TRADIER_REST_BASE_URL", &server.base_url()),
                               ("TRADIER_ACCESS_TOKEN", "testToken")], || {
                let config = Config::new();
                let sut = BlockingTradierRestClient::new(config).expect("client to initialize");
                let result = sut.get_account_balances(&ascii_string.parse().expect("valid ascii"));
                operation.assert();
                assert_eq!(operation.calls(), 1);
                assert!(result.is_ok(), "request failed: {:?}", result.as_ref().err());
                operation.delete();
            });
        });

        proptest!(|(response in any::<GetAccountPositionsResponseWire>(),
                    ascii_string in printable_ascii_account())| {
            let mut operation = server.mock(|when, then| {
                when.path(url::Url::parse(&server.url(format!("/v1/accounts/{ascii_string}/positions"))).unwrap().path())
                    .header("accept", "application/json");
                then.status(200)
                    .header("content-type", "application/json")
                    .body(serde_json::to_vec(&response)
                        .expect("serialization of wire type for tests to work"));
            });
            with_env_vars(vec![("TRADIER_REST_BASE_URL", &server.base_url()),
                               ("TRADIER_ACCESS_TOKEN", "testToken")], || {
                let config = Config::new();
                let sut = BlockingTradierRestClient::new(config).expect("client to initialize");
                let result = sut.get_account_positions(&ascii_string.parse().expect("valid ascii"));
                operation.assert();
                assert_eq!(operation.calls(), 1);
                assert!(result.is_ok(), "request failed: {:?}", result.as_ref().err());
                operation.delete();
            });
        });
    }

    /// Constructing the blocking client inside a Tokio runtime must fail with the dedicated
    /// error variant, not with some unrelated error.
    #[tokio::test]
    async fn test_should_not_be_able_to_create_within_an_async_runtime() {
        let config = Config::new();
        let sut = BlockingTradierRestClient::new(config);
        assert!(
            matches!(sut, Err(crate::Error::BlockingClientInsideAsyncRuntime)),
            "expected BlockingClientInsideAsyncRuntime, got {:?}",
            sut.as_ref().err()
        );
    }
}