use tracing::instrument;
use crate::client::Client;
use crate::datatables::{
DataTablesColumn, DataTablesRequest, DataTablesResponse, fetch_limit,
impl_datatables_request_methods,
};
use crate::error::Result;
use crate::models::Trade;
/// Endpoint path for the institutional volume table.
///
/// POSTed (as form data) by `Client::get_institutional_volume` and
/// `Client::get_institutional_volume_limit`; the tests mount a mock at this
/// path under the client's configured `base_url`.
pub(crate) const INSTITUTIONAL_VOLUME_PATH: &str = "/InstitutionalVolume/GetInstitutionalVolume";
/// Endpoint path for the "AH" institutional volume table (presumably
/// after-hours — TODO confirm).
///
/// POSTed by `Client::get_ah_institutional_volume` and
/// `Client::get_ah_institutional_volume_limit`.
pub(crate) const AH_INSTITUTIONAL_VOLUME_PATH: &str =
"/AHInstitutionalVolume/GetAHInstitutionalVolume";
/// Endpoint path for the total volume table.
///
/// POSTed by `Client::get_total_volume` and `Client::get_total_volume_limit`.
pub(crate) const TOTAL_VOLUME_PATH: &str = "/TotalVolume/GetTotalVolume";
#[derive(Clone, Debug)]
pub struct VolumeRequest(pub(crate) DataTablesRequest);
impl_datatables_request_methods!(VolumeRequest);
impl VolumeRequest {
#[must_use]
pub fn institutional() -> Self {
Self(DataTablesRequest {
columns: institutional_volume_columns(),
..DataTablesRequest::default()
})
}
#[must_use]
pub fn ah_institutional() -> Self {
Self(DataTablesRequest {
columns: ah_institutional_volume_columns(),
..DataTablesRequest::default()
})
}
#[must_use]
pub fn total() -> Self {
Self(DataTablesRequest {
columns: total_volume_columns(),
..DataTablesRequest::default()
})
}
#[must_use]
pub fn with_date(mut self, date: impl Into<String>) -> Self {
self.0 = self.0.with_extra_value("Date", date);
self
}
#[must_use]
pub fn with_tickers(mut self, tickers: impl Into<String>) -> Self {
self.0 = self.0.with_extra_value("Tickers", tickers);
self
}
pub(crate) fn to_pairs(&self) -> Vec<(String, String)> {
self.0.to_pairs()
}
}
/// Builds the ten-column DataTables layout shared by every volume endpoint.
///
/// Only positions 5–7 vary between endpoints (`volume`, `dollars`, `rank`);
/// the other seven names are fixed. Each column is built with the same string
/// for both of its first two `DataTablesColumn::new` arguments and with both
/// boolean flags set to `true`. The tests read these back as `data`,
/// `searchable` and `orderable`.
///
/// The repeated `Ticker` and `LastComparibleTradeDate` entries are deliberate:
/// the tests pin this exact layout. Keep the `Comparible` spelling as-is. It
/// presumably mirrors the server's field name.
fn volume_columns(volume: &str, dollars: &str, rank: &str) -> Vec<DataTablesColumn> {
    const LAST_TRADE_DATE: &str = "LastComparibleTradeDate";

    let names = [
        "Ticker",
        "Ticker",
        "Price",
        "Sector",
        "Industry",
        volume,
        dollars,
        rank,
        LAST_TRADE_DATE,
        LAST_TRADE_DATE,
    ];

    names
        .iter()
        .map(|&name| DataTablesColumn::new(name, name, true, true))
        .collect()
}
/// Returns the column layout sent by [`VolumeRequest::institutional`].
///
/// The endpoint-specific middle columns are `TotalInstitutionalVolume`,
/// `TotalInstitutionalDollars` and `TotalInstitutionalDollarsRank`.
#[must_use]
pub fn institutional_volume_columns() -> Vec<DataTablesColumn> {
    let volume = "TotalInstitutionalVolume";
    let dollars = "TotalInstitutionalDollars";
    let rank = "TotalInstitutionalDollarsRank";
    volume_columns(volume, dollars, rank)
}
/// Returns the column layout sent by [`VolumeRequest::ah_institutional`].
///
/// The endpoint-specific middle columns are `AHInstitutionalVolume`,
/// `AHInstitutionalDollars` and `AHInstitutionalDollarsRank`.
#[must_use]
pub fn ah_institutional_volume_columns() -> Vec<DataTablesColumn> {
    let volume = "AHInstitutionalVolume";
    let dollars = "AHInstitutionalDollars";
    let rank = "AHInstitutionalDollarsRank";
    volume_columns(volume, dollars, rank)
}
/// Returns the column layout sent by [`VolumeRequest::total`].
///
/// The endpoint-specific middle columns are `TotalVolume`, `TotalDollars`
/// and `TotalDollarsRank`.
#[must_use]
pub fn total_volume_columns() -> Vec<DataTablesColumn> {
    let volume = "TotalVolume";
    let dollars = "TotalDollars";
    let rank = "TotalDollarsRank";
    volume_columns(volume, dollars, rank)
}
impl Client {
    /// Fetches a single DataTables response of institutional volume trades.
    ///
    /// POSTs `request` as form data to the institutional volume endpoint and
    /// deserializes the JSON body.
    ///
    /// # Errors
    ///
    /// Returns an error if the POST fails or if the body cannot be deserialized
    /// into a `DataTablesResponse<Trade>`.
    #[instrument(skip_all)]
    pub async fn get_institutional_volume(
        &self,
        request: &VolumeRequest,
    ) -> Result<DataTablesResponse<Trade>> {
        post_volume_request(self, INSTITUTIONAL_VOLUME_PATH, request).await
    }

    /// Fetches up to `limit` institutional volume trades.
    ///
    /// Delegates to `fetch_limit`, passing a clone of the wrapped request.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `fetch_limit`.
    #[instrument(skip_all)]
    pub async fn get_institutional_volume_limit(
        &self,
        request: &VolumeRequest,
        limit: usize,
    ) -> Result<Vec<Trade>> {
        fetch_limit(self, INSTITUTIONAL_VOLUME_PATH, request.0.clone(), limit).await
    }

    /// Fetches a single DataTables response of "AH" institutional volume trades.
    ///
    /// POSTs `request` as form data to the AH institutional volume endpoint
    /// and deserializes the JSON body.
    ///
    /// # Errors
    ///
    /// Returns an error if the POST fails or if the body cannot be deserialized
    /// into a `DataTablesResponse<Trade>`.
    #[instrument(skip_all)]
    pub async fn get_ah_institutional_volume(
        &self,
        request: &VolumeRequest,
    ) -> Result<DataTablesResponse<Trade>> {
        post_volume_request(self, AH_INSTITUTIONAL_VOLUME_PATH, request).await
    }

    /// Fetches up to `limit` "AH" institutional volume trades.
    ///
    /// Delegates to `fetch_limit`, passing a clone of the wrapped request.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `fetch_limit`.
    #[instrument(skip_all)]
    pub async fn get_ah_institutional_volume_limit(
        &self,
        request: &VolumeRequest,
        limit: usize,
    ) -> Result<Vec<Trade>> {
        fetch_limit(self, AH_INSTITUTIONAL_VOLUME_PATH, request.0.clone(), limit).await
    }

    /// Fetches a single DataTables response of total volume trades.
    ///
    /// POSTs `request` as form data to the total volume endpoint and
    /// deserializes the JSON body.
    ///
    /// # Errors
    ///
    /// Returns an error if the POST fails or if the body cannot be deserialized
    /// into a `DataTablesResponse<Trade>`.
    #[instrument(skip_all)]
    pub async fn get_total_volume(
        &self,
        request: &VolumeRequest,
    ) -> Result<DataTablesResponse<Trade>> {
        post_volume_request(self, TOTAL_VOLUME_PATH, request).await
    }

    /// Fetches up to `limit` total volume trades.
    ///
    /// Delegates to `fetch_limit`, passing a clone of the wrapped request.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `fetch_limit`.
    #[instrument(skip_all)]
    pub async fn get_total_volume_limit(
        &self,
        request: &VolumeRequest,
        limit: usize,
    ) -> Result<Vec<Trade>> {
        fetch_limit(self, TOTAL_VOLUME_PATH, request.0.clone(), limit).await
    }
}

/// POSTs `request` as form pairs to `path` and deserializes the JSON body.
///
/// This is the shared implementation of the three single-response volume
/// methods, which previously repeated it verbatim. It is a free function
/// rather than a `Client` method, so it cannot clash with inherent methods
/// declared in other `impl Client` blocks.
///
/// # Errors
///
/// Returns the `post_form` error, or the converted `serde_json` error if the
/// body is not a valid `DataTablesResponse<Trade>`.
async fn post_volume_request(
    client: &Client,
    path: &'static str,
    request: &VolumeRequest,
) -> Result<DataTablesResponse<Trade>> {
    let body = client.post_form(path, request.to_pairs()).await?;
    Ok(serde_json::from_str(&body)?)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::ClientConfig;
    use crate::session::{
        COOKIE_DOMAIN, Cookie, FORMS_AUTH_COOKIE_NAME, SESSION_COOKIE_NAME, Session,
    };

    /// Number of columns every volume layout is expected to contain.
    const EXPECTED_COLUMN_COUNT: usize = 10;

    /// Builds a session with fake session and forms-auth cookies and an XSRF token.
    fn test_session() -> Session {
        let cookies = vec![
            Cookie::new(SESSION_COOKIE_NAME, "session-123", COOKIE_DOMAIN),
            Cookie::new(FORMS_AUTH_COOKIE_NAME, "auth-456", COOKIE_DOMAIN),
        ];
        Session::new(cookies, "xsrf-789")
    }

    /// Builds a client whose base URL points at the given mock server.
    fn test_client(server: &mockito::Server) -> Client {
        let config = ClientConfig {
            base_url: server.url(),
            ..ClientConfig::default()
        };
        Client::with_config(test_session(), config).unwrap()
    }

    /// Reads the canned two-row volume response.
    fn volume_fixture() -> String {
        crate::test_support::read_fixture("volume_response.json")
    }

    /// Registers a `POST path` mock that replies 200 with the JSON volume fixture.
    async fn mock_volume_endpoint(
        server: &mut mockito::Server,
        path: &'static str,
    ) -> mockito::Mock {
        server
            .mock("POST", path)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(volume_fixture())
            .create_async()
            .await
    }

    /// Checks the fields that a response built from the fixture must carry.
    fn assert_fixture_page(response: &DataTablesResponse<Trade>) {
        assert_eq!(response.draw, 1);
        assert_eq!(response.records_total, 120);
        assert_eq!(response.data.len(), 2);
        assert_eq!(response.data[0].ticker.as_deref(), Some("AAPL"));
        assert_eq!(response.data[1].ticker.as_deref(), Some("MSFT"));
    }

    /// Checks that a `limit = 1` fetch kept only the fixture's first trade.
    fn assert_single_trade(trades: &[Trade]) {
        assert_eq!(trades.len(), 1);
        assert_eq!(trades[0].ticker.as_deref(), Some("AAPL"));
    }

    /// Checks that columns 5..8 carry the endpoint-specific `expected` names.
    fn assert_middle_fields(columns: &[DataTablesColumn], expected: [&str; 3]) {
        for (column, name) in columns[5..8].iter().zip(expected) {
            assert_eq!(column.data, name);
        }
    }

    #[test]
    fn institutional_volume_columns_returns_10_columns() {
        assert_eq!(institutional_volume_columns().len(), EXPECTED_COLUMN_COUNT);
    }

    #[test]
    fn ah_institutional_volume_columns_returns_10_columns() {
        assert_eq!(
            ah_institutional_volume_columns().len(),
            EXPECTED_COLUMN_COUNT
        );
    }

    #[test]
    fn total_volume_columns_returns_10_columns() {
        assert_eq!(total_volume_columns().len(), EXPECTED_COLUMN_COUNT);
    }

    #[test]
    fn volume_columns_share_leading_and_trailing_layout() {
        const LEADING: [&str; 5] = ["Ticker", "Ticker", "Price", "Sector", "Industry"];
        const TRAILING: [&str; 2] = ["LastComparibleTradeDate", "LastComparibleTradeDate"];

        let layouts = [
            institutional_volume_columns(),
            ah_institutional_volume_columns(),
            total_volume_columns(),
        ];
        for columns in &layouts {
            for (column, name) in columns.iter().zip(LEADING) {
                assert_eq!(column.data, name);
            }
            for (column, name) in columns[8..10].iter().zip(TRAILING) {
                assert_eq!(column.data, name);
            }
            assert!(columns.iter().all(|c| c.searchable && c.orderable));
        }
    }

    #[test]
    fn institutional_volume_columns_middle_fields_match_go_source() {
        assert_middle_fields(
            &institutional_volume_columns(),
            [
                "TotalInstitutionalVolume",
                "TotalInstitutionalDollars",
                "TotalInstitutionalDollarsRank",
            ],
        );
    }

    #[test]
    fn ah_institutional_volume_columns_middle_fields_match_go_source() {
        assert_middle_fields(
            &ah_institutional_volume_columns(),
            [
                "AHInstitutionalVolume",
                "AHInstitutionalDollars",
                "AHInstitutionalDollarsRank",
            ],
        );
    }

    #[test]
    fn total_volume_columns_middle_fields_match_go_source() {
        assert_middle_fields(
            &total_volume_columns(),
            ["TotalVolume", "TotalDollars", "TotalDollarsRank"],
        );
    }

    #[tokio::test]
    async fn get_institutional_volume_returns_fixture_response() {
        let mut server = mockito::Server::new_async().await;
        let mock = mock_volume_endpoint(&mut server, INSTITUTIONAL_VOLUME_PATH).await;
        let client = test_client(&server);

        let response = client
            .get_institutional_volume(&VolumeRequest::institutional())
            .await
            .unwrap();

        assert_fixture_page(&response);
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn get_institutional_volume_limit_respects_limit() {
        let mut server = mockito::Server::new_async().await;
        mock_volume_endpoint(&mut server, INSTITUTIONAL_VOLUME_PATH).await;
        let client = test_client(&server);

        let trades = client
            .get_institutional_volume_limit(&VolumeRequest::institutional(), 1)
            .await
            .unwrap();

        assert_single_trade(&trades);
    }

    #[tokio::test]
    async fn get_ah_institutional_volume_returns_fixture_response() {
        let mut server = mockito::Server::new_async().await;
        let mock = mock_volume_endpoint(&mut server, AH_INSTITUTIONAL_VOLUME_PATH).await;
        let client = test_client(&server);

        let response = client
            .get_ah_institutional_volume(&VolumeRequest::ah_institutional())
            .await
            .unwrap();

        assert_fixture_page(&response);
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn get_ah_institutional_volume_limit_respects_limit() {
        let mut server = mockito::Server::new_async().await;
        mock_volume_endpoint(&mut server, AH_INSTITUTIONAL_VOLUME_PATH).await;
        let client = test_client(&server);

        let trades = client
            .get_ah_institutional_volume_limit(&VolumeRequest::ah_institutional(), 1)
            .await
            .unwrap();

        assert_single_trade(&trades);
    }

    #[tokio::test]
    async fn get_total_volume_returns_fixture_response() {
        let mut server = mockito::Server::new_async().await;
        let mock = mock_volume_endpoint(&mut server, TOTAL_VOLUME_PATH).await;
        let client = test_client(&server);

        let response = client
            .get_total_volume(&VolumeRequest::total())
            .await
            .unwrap();

        assert_fixture_page(&response);
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn get_total_volume_limit_respects_limit() {
        let mut server = mockito::Server::new_async().await;
        mock_volume_endpoint(&mut server, TOTAL_VOLUME_PATH).await;
        let client = test_client(&server);

        let trades = client
            .get_total_volume_limit(&VolumeRequest::total(), 1)
            .await
            .unwrap();

        assert_single_trade(&trades);
    }
}