use std::{num::NonZeroU64, path::PathBuf};
use dbn::{encode::AsyncDbnEncoder, Compression, Encoding, SType, Schema, VersionUpgradePolicy};
use futures::{Stream, TryStreamExt};
use reqwest::{header::ACCEPT, RequestBuilder};
use tokio::{
fs::File,
io::{AsyncReadExt, AsyncWriteExt, BufWriter},
};
use tokio_util::{bytes::Bytes, io::StreamReader};
use typed_builder::TypedBuilder;
use crate::Symbols;
use super::{check_http_error, DateTimeRange};
pub use dbn::decode::AsyncDbnDecoder;
/// Client for the `timeseries.*` endpoints of the historical API.
///
/// Obtained from the parent client. Every request is delegated to the
/// borrowed parent through `inner`.
#[derive(Debug)]
pub struct TimeseriesClient<'a> {
    // Parent client that builds and sends the HTTP requests; borrowed mutably
    // for the lifetime of this sub-client.
    pub(crate) inner: &'a mut super::Client,
}
impl TimeseriesClient<'_> {
pub async fn get_range(
&mut self,
params: &GetRangeParams,
) -> crate::Result<AsyncDbnDecoder<impl AsyncReadExt>> {
let reader = self
.get_range_impl(
¶ms.dataset,
params.schema,
params.stype_in,
params.stype_out,
¶ms.symbols,
¶ms.date_time_range,
params.limit,
)
.await?;
let mut decoder: AsyncDbnDecoder<_> = AsyncDbnDecoder::with_zstd_buffer(reader).await?;
decoder.set_upgrade_policy(params.upgrade_policy)?;
Ok(decoder)
}
pub async fn get_range_to_file(
&mut self,
params: &GetRangeToFileParams,
) -> crate::Result<AsyncDbnDecoder<impl AsyncReadExt>> {
let reader = self
.get_range_impl(
¶ms.dataset,
params.schema,
params.stype_in,
params.stype_out,
¶ms.symbols,
¶ms.date_time_range,
params.limit,
)
.await?;
let mut http_decoder = AsyncDbnDecoder::with_zstd_buffer(reader).await?;
http_decoder.set_upgrade_policy(params.upgrade_policy)?;
let file = BufWriter::new(File::create(¶ms.path).await?);
let mut encoder = AsyncDbnEncoder::with_zstd(file, http_decoder.metadata()).await?;
while let Some(rec_ref) = http_decoder.decode_record_ref().await? {
encoder.encode_record_ref(rec_ref).await?;
}
encoder.get_mut().shutdown().await?;
Ok(AsyncDbnDecoder::from_zstd_file(¶ms.path).await?)
}
#[allow(clippy::too_many_arguments)] async fn get_range_impl(
&mut self,
dataset: &str,
schema: Schema,
stype_in: SType,
stype_out: SType,
symbols: &Symbols,
date_time_range: &DateTimeRange,
limit: Option<NonZeroU64>,
) -> crate::Result<StreamReader<impl Stream<Item = std::io::Result<Bytes>>, Bytes>> {
let mut form = vec![
("dataset", dataset.to_owned()),
("schema", schema.to_string()),
("encoding", Encoding::Dbn.to_string()),
("compression", Compression::ZStd.to_string()),
("stype_in", stype_in.to_string()),
("stype_out", stype_out.to_string()),
("symbols", symbols.to_api_string()),
];
date_time_range.add_to_form(&mut form);
if let Some(limit) = limit {
form.push(("limit", limit.to_string()));
}
let resp = self
.post("get_range")?
.header(ACCEPT, "application/octet-stream")
.form(&form)
.send()
.await?;
let stream = check_http_error(resp)
.await?
.error_for_status()?
.bytes_stream()
.map_err(std::io::Error::other);
Ok(tokio_util::io::StreamReader::new(stream))
}
fn post(&mut self, slug: &str) -> crate::Result<RequestBuilder> {
self.inner.post(&format!("timeseries.{slug}"))
}
}
/// Parameters for [`TimeseriesClient::get_range()`].
#[derive(Debug, Clone, TypedBuilder, PartialEq, Eq)]
pub struct GetRangeParams {
    /// The dataset code. The builder setter accepts any `ToString` value.
    #[builder(setter(transform = |dt: impl ToString| dt.to_string()))]
    pub dataset: String,
    /// The symbols to request. The builder setter accepts anything `Into<Symbols>`.
    #[builder(setter(into))]
    pub symbols: Symbols,
    /// The record schema to request.
    pub schema: Schema,
    /// The request's time range, sent as the form's start/end fields.
    #[builder(setter(into))]
    pub date_time_range: DateTimeRange,
    /// The symbology type of `symbols`. Defaults to [`SType::RawSymbol`].
    #[builder(default = SType::RawSymbol)]
    pub stype_in: SType,
    /// The symbology type to map to. Defaults to [`SType::InstrumentId`].
    #[builder(default = SType::InstrumentId)]
    pub stype_out: SType,
    /// Optional record limit. Defaults to `None`, in which case no `limit`
    /// field is sent with the request.
    #[builder(default)]
    pub limit: Option<NonZeroU64>,
    /// Policy applied to the returned decoder for data encoded in older DBN
    /// versions. Defaults to [`VersionUpgradePolicy::UpgradeToV2`].
    #[builder(default = VersionUpgradePolicy::UpgradeToV2)]
    pub upgrade_policy: VersionUpgradePolicy,
}
/// Parameters for [`TimeseriesClient::get_range_to_file()`].
///
/// Identical to [`GetRangeParams`] plus the output `path`.
#[derive(Debug, Clone, TypedBuilder, PartialEq, Eq)]
pub struct GetRangeToFileParams {
    /// The dataset code. The builder setter accepts any `ToString` value.
    #[builder(setter(transform = |dt: impl ToString| dt.to_string()))]
    pub dataset: String,
    /// The symbols to request. The builder setter accepts anything `Into<Symbols>`.
    #[builder(setter(into))]
    pub symbols: Symbols,
    /// The record schema to request.
    pub schema: Schema,
    /// The request's time range, sent as the form's start/end fields.
    #[builder(setter(into))]
    pub date_time_range: DateTimeRange,
    /// The symbology type of `symbols`. Defaults to [`SType::RawSymbol`].
    #[builder(default = SType::RawSymbol)]
    pub stype_in: SType,
    /// The symbology type to map to. Defaults to [`SType::InstrumentId`].
    #[builder(default = SType::InstrumentId)]
    pub stype_out: SType,
    /// Optional record limit. Defaults to `None`, in which case no `limit`
    /// field is sent with the request.
    #[builder(default)]
    pub limit: Option<NonZeroU64>,
    /// Policy applied when decoding data encoded in older DBN versions.
    /// Defaults to [`VersionUpgradePolicy::UpgradeToV2`].
    #[builder(default = VersionUpgradePolicy::UpgradeToV2)]
    pub upgrade_policy: VersionUpgradePolicy,
    /// Destination file for the zstd-compressed DBN data.
    // NOTE(review): the builder defaults this to an empty `PathBuf`, which
    // `File::create` would presumably reject at request time — confirm whether
    // this should be a required builder field.
    #[builder(default, setter(transform = |p: impl Into<PathBuf>| p.into()))]
    pub path: PathBuf,
}
impl From<GetRangeToFileParams> for GetRangeParams {
fn from(value: GetRangeToFileParams) -> Self {
Self {
dataset: value.dataset,
symbols: value.symbols,
schema: value.schema,
date_time_range: value.date_time_range,
stype_in: value.stype_in,
stype_out: value.stype_out,
limit: value.limit,
upgrade_policy: value.upgrade_policy,
}
}
}
impl GetRangeParams {
    /// Turns these request parameters into [`GetRangeToFileParams`] whose output
    /// is written to `path`.
    pub fn with_path(self, path: impl Into<PathBuf>) -> GetRangeToFileParams {
        let Self {
            dataset,
            symbols,
            schema,
            date_time_range,
            stype_in,
            stype_out,
            limit,
            upgrade_policy,
        } = self;
        GetRangeToFileParams {
            dataset,
            symbols,
            schema,
            date_time_range,
            stype_in,
            stype_out,
            limit,
            upgrade_policy,
            path: path.into(),
        }
    }
}
#[cfg(test)]
mod tests {
    use dbn::{record::TradeMsg, Dataset};
    use reqwest::StatusCode;
    use time::macros::datetime;
    use wiremock::{
        matchers::{basic_auth, method, path},
        Mock, MockServer, ResponseTemplate,
    };

    use super::*;
    use crate::{
        body_contains,
        historical::{HistoricalGateway, API_VERSION},
        zst_test_data_path, HistoricalClient,
    };

    const API_KEY: &str = "test-API";

    /// Creates a historical client that targets `mock_server` and authenticates
    /// with [`API_KEY`].
    fn client_for(mock_server: &MockServer) -> HistoricalClient {
        HistoricalClient::with_url(mock_server.uri(), API_KEY.to_owned(), HistoricalGateway::Bo1)
            .unwrap()
    }

    #[tokio::test]
    async fn test_get_range() {
        const START: time::OffsetDateTime = datetime!(2023 - 06 - 14 00:00 UTC);
        const END: time::OffsetDateTime = datetime!(2023 - 06 - 17 00:00 UTC);
        const SCHEMA: Schema = Schema::Trades;

        let mock_server = MockServer::start().await;
        let body = tokio::fs::read(zst_test_data_path(SCHEMA)).await.unwrap();
        Mock::given(method("POST"))
            .and(basic_auth(API_KEY, ""))
            .and(path(format!("/v{API_VERSION}/timeseries.get_range")))
            .and(body_contains("dataset", "XNAS.ITCH"))
            .and(body_contains("schema", "trades"))
            .and(body_contains("symbols", "SPOT%2CAAPL"))
            .and(body_contains(
                "start",
                START.unix_timestamp_nanos().to_string(),
            ))
            .and(body_contains("end", END.unix_timestamp_nanos().to_string()))
            .and(body_contains("stype_in", "raw_symbol"))
            .and(body_contains("stype_out", "instrument_id"))
            .respond_with(ResponseTemplate::new(StatusCode::OK.as_u16()).set_body_bytes(body))
            .mount(&mock_server)
            .await;

        let mut client = client_for(&mock_server);
        let params = GetRangeParams::builder()
            .dataset(Dataset::XnasItch)
            .schema(SCHEMA)
            .symbols(vec!["SPOT", "AAPL"])
            .date_time_range((START, END))
            .build();
        let mut decoder = client.timeseries().get_range(&params).await.unwrap();

        assert_eq!(decoder.metadata().schema.unwrap(), SCHEMA);
        decoder.decode_record::<TradeMsg>().await.unwrap().unwrap();
        decoder.decode_record::<TradeMsg>().await.unwrap().unwrap();
        assert!(decoder.decode_record::<TradeMsg>().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_get_range_to_file() {
        const START: time::OffsetDateTime = datetime!(2024 - 05 - 17 00:00 UTC);
        const END: time::OffsetDateTime = datetime!(2024 - 05 - 18 00:00 UTC);
        const SCHEMA: Schema = Schema::Trades;
        const DATASET: &str = Dataset::IfeuImpact.as_str();

        let mock_server = MockServer::start().await;
        let temp_dir = tempfile::TempDir::new().unwrap();
        let body = tokio::fs::read(zst_test_data_path(SCHEMA)).await.unwrap();
        Mock::given(method("POST"))
            .and(basic_auth(API_KEY, ""))
            .and(path(format!("/v{API_VERSION}/timeseries.get_range")))
            .and(body_contains("dataset", DATASET))
            .and(body_contains("schema", "trades"))
            .and(body_contains("symbols", "BRN.FUT"))
            .and(body_contains(
                "start",
                START.unix_timestamp_nanos().to_string(),
            ))
            .and(body_contains("end", END.unix_timestamp_nanos().to_string()))
            .and(body_contains("stype_in", "parent"))
            .and(body_contains("stype_out", "instrument_id"))
            .respond_with(ResponseTemplate::new(StatusCode::OK.as_u16()).set_body_bytes(body))
            .mount(&mock_server)
            .await;

        let mut client = client_for(&mock_server);
        let output_path = temp_dir.path().join("test.dbn.zst");
        let params = GetRangeToFileParams::builder()
            .dataset(DATASET)
            .schema(SCHEMA)
            .symbols(vec!["BRN.FUT"])
            .stype_in(SType::Parent)
            .date_time_range((START, END))
            .path(output_path.clone())
            .build();
        let mut decoder = client
            .timeseries()
            .get_range_to_file(&params)
            .await
            .unwrap();

        assert_eq!(decoder.metadata().schema.unwrap(), SCHEMA);
        decoder.decode_record::<TradeMsg>().await.unwrap().unwrap();
        decoder.decode_record::<TradeMsg>().await.unwrap().unwrap();
        assert!(decoder.decode_record::<TradeMsg>().await.unwrap().is_none());
    }
}