use std::{num::NonZeroU64, path::PathBuf};
use async_compression::tokio::bufread::ZstdDecoder;
use dbn::{
decode::DbnMetadata,
encode::{AsyncDbnEncoder, AsyncEncodeRecordRef},
Compression, Encoding, SType, Schema, VersionUpgradePolicy,
};
use futures::{Stream, TryStreamExt};
use reqwest::{header::ACCEPT, RequestBuilder};
use tokio::{
fs::File,
io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter},
};
use tokio_util::{bytes::Bytes, io::StreamReader};
use tracing::{error, instrument};
use crate::{
historical::{check_warnings, AddToForm, Limit},
Symbols,
};
use super::{check_http_error, DateTimeRange};
pub use dbn::decode::AsyncDbnDecoder;
/// A client for the `timeseries.*` endpoints of the historical API.
///
/// It mutably borrows the parent historical client for its whole lifetime and
/// issues every request through it.
#[derive(Debug)]
pub struct TimeseriesClient<'a> {
    /// The parent client used to build and send requests.
    pub(crate) inner: &'a mut super::Client,
}
impl TimeseriesClient<'_> {
#[instrument(name = "timeseries.get_range")]
pub async fn get_range(
&mut self,
params: &GetRangeParams,
) -> crate::Result<AsyncDbnDecoder<impl AsyncReadExt>> {
let reader = self
.get_range_impl(
¶ms.dataset,
params.schema,
params.stype_in,
params.stype_out,
¶ms.symbols,
¶ms.date_time_range,
params.limit,
)
.await?;
#[expect(deprecated)]
let upgrade_policy = params.upgrade_policy.unwrap_or(self.inner.upgrade_policy());
Ok(AsyncDbnDecoder::with_upgrade_policy(zstd_decoder(reader), upgrade_policy).await?)
}
#[instrument(name = "timeseries.get_range_to_file")]
pub async fn get_range_to_file(
&mut self,
params: &GetRangeToFileParams,
) -> crate::Result<AsyncDbnDecoder<ZstdDecoder<BufReader<File>>>> {
let reader = self
.get_range_impl(
¶ms.dataset,
params.schema,
params.stype_in,
params.stype_out,
¶ms.symbols,
¶ms.date_time_range,
params.limit,
)
.await?;
#[expect(deprecated)]
let upgrade_policy = params.upgrade_policy.unwrap_or(self.inner.upgrade_policy());
let mut http_decoder =
AsyncDbnDecoder::with_upgrade_policy(zstd_decoder(reader), upgrade_policy).await?;
let file = BufWriter::new(File::create(¶ms.path).await?);
let mut encoder = AsyncDbnEncoder::with_zstd(file, http_decoder.metadata()).await?;
while let Some(rec_ref) = http_decoder.decode_record_ref().await? {
encoder.encode_record_ref(rec_ref).await?;
}
encoder.get_mut().shutdown().await?;
Ok(AsyncDbnDecoder::with_upgrade_policy(
zstd_decoder(BufReader::new(File::open(¶ms.path).await?)),
VersionUpgradePolicy::AsIs,
)
.await?)
}
#[allow(clippy::too_many_arguments)] async fn get_range_impl(
&mut self,
dataset: &str,
schema: Schema,
stype_in: SType,
stype_out: SType,
symbols: &Symbols,
date_time_range: &DateTimeRange,
limit: Option<NonZeroU64>,
) -> crate::Result<StreamReader<impl Stream<Item = std::io::Result<Bytes>>, Bytes>> {
let form = vec![
("dataset", dataset.to_owned()),
("schema", schema.to_string()),
("encoding", Encoding::Dbn.to_string()),
("compression", Compression::Zstd.to_string()),
("stype_in", stype_in.to_string()),
("stype_out", stype_out.to_string()),
("symbols", symbols.to_api_string()),
]
.add_to_form(date_time_range)
.add_to_form(&Limit(limit));
let resp = self
.post("get_range")?
.header(ACCEPT, "application/octet-stream")
.form(&form)
.send()
.await?;
check_warnings(&resp);
let stream = check_http_error(resp)
.await?
.error_for_status()?
.bytes_stream()
.map_err(|err| {
error!(?err, "Failed reading from HTTP stream");
std::io::Error::other(err)
});
Ok(tokio_util::io::StreamReader::new(stream))
}
fn post(&mut self, slug: &str) -> crate::Result<RequestBuilder> {
self.inner.post(&format!("timeseries.{slug}"))
}
}
/// The parameters for [`TimeseriesClient::get_range()`].
///
/// Construct an instance with the generated builder, `GetRangeParams::builder()`.
#[derive(Debug, Clone, bon::Builder, PartialEq, Eq)]
pub struct GetRangeParams {
    /// The dataset to query, sent as the `dataset` form field. The builder
    /// accepts any value implementing `ToString`.
    #[builder(with = |d: impl ToString| d.to_string())]
    pub dataset: String,
    /// The symbols to request data for.
    #[builder(into)]
    pub symbols: Symbols,
    /// The data record schema to request.
    pub schema: Schema,
    /// The date-time range of the request.
    #[builder(into)]
    pub date_time_range: DateTimeRange,
    /// The symbology type of the input `symbols`. Defaults to
    /// [`SType::RawSymbol`].
    #[builder(default = SType::RawSymbol)]
    pub stype_in: SType,
    /// The symbology type requested for the output. Defaults to
    /// [`SType::InstrumentId`].
    #[builder(default = SType::InstrumentId)]
    pub stype_out: SType,
    /// An optional limit forwarded to the API with the request.
    pub limit: Option<NonZeroU64>,
    /// An optional override for the DBN version upgrade policy. When `None`,
    /// the upgrade policy configured on the client is used.
    #[deprecated(
        since = "0.28.0",
        note = "Use the upgrade_policy configuration option on HistoricalClient"
    )]
    pub upgrade_policy: Option<VersionUpgradePolicy>,
}
/// The parameters for [`TimeseriesClient::get_range_to_file()`].
///
/// Construct an instance with the generated builder,
/// `GetRangeToFileParams::builder()`, or convert existing [`GetRangeParams`]
/// with [`GetRangeParams::with_path()`].
#[derive(Debug, Clone, bon::Builder, PartialEq, Eq)]
pub struct GetRangeToFileParams {
    /// The dataset to query, sent as the `dataset` form field. The builder
    /// accepts any value implementing `ToString`.
    #[builder(with = |d: impl ToString| d.to_string())]
    pub dataset: String,
    /// The symbols to request data for.
    #[builder(into)]
    pub symbols: Symbols,
    /// The data record schema to request.
    pub schema: Schema,
    /// The date-time range of the request.
    #[builder(into)]
    pub date_time_range: DateTimeRange,
    /// The symbology type of the input `symbols`. Defaults to
    /// [`SType::RawSymbol`].
    #[builder(default = SType::RawSymbol)]
    pub stype_in: SType,
    /// The symbology type requested for the output. Defaults to
    /// [`SType::InstrumentId`].
    #[builder(default = SType::InstrumentId)]
    pub stype_out: SType,
    /// An optional limit forwarded to the API with the request.
    pub limit: Option<NonZeroU64>,
    /// An optional override for the DBN version upgrade policy. When `None`,
    /// the upgrade policy configured on the client is used.
    #[deprecated(
        since = "0.28.0",
        note = "Use the upgrade_policy configuration option on HistoricalClient"
    )]
    pub upgrade_policy: Option<VersionUpgradePolicy>,
    /// The path of the file the response data is written to. The builder
    /// falls back to `PathBuf::default()` when no path is given.
    #[builder(default, into)]
    pub path: PathBuf,
}
impl From<GetRangeToFileParams> for GetRangeParams {
fn from(value: GetRangeToFileParams) -> Self {
Self {
dataset: value.dataset,
symbols: value.symbols,
schema: value.schema,
date_time_range: value.date_time_range,
stype_in: value.stype_in,
stype_out: value.stype_out,
limit: value.limit,
#[expect(deprecated)]
upgrade_policy: value.upgrade_policy,
}
}
}
impl GetRangeParams {
pub fn with_path(self, path: impl Into<PathBuf>) -> GetRangeToFileParams {
GetRangeToFileParams {
dataset: self.dataset,
symbols: self.symbols,
schema: self.schema,
date_time_range: self.date_time_range,
stype_in: self.stype_in,
stype_out: self.stype_out,
limit: self.limit,
#[expect(deprecated)]
upgrade_policy: self.upgrade_policy,
path: path.into(),
}
}
}
/// Wraps `reader` in a Zstandard decoder with multi-member decoding enabled.
fn zstd_decoder<R: tokio::io::AsyncBufReadExt + Unpin>(
    reader: R,
) -> async_compression::tokio::bufread::ZstdDecoder<R> {
    let mut decoder = async_compression::tokio::bufread::ZstdDecoder::new(reader);
    // Keep decoding when the input consists of several concatenated frames
    // rather than stopping at the end of the first one.
    decoder.multiple_members(true);
    decoder
}
#[cfg(test)]
mod tests {
    use dbn::{record::TradeMsg, Dataset};
    use reqwest::StatusCode;
    use rstest::*;
    use time::macros::datetime;
    use wiremock::{
        matchers::{basic_auth, method, path},
        Mock, MockServer, ResponseTemplate,
    };

    use super::*;
    use crate::{
        body_contains, historical::test_infra::API_KEY, historical::API_VERSION,
        zst_test_data_path, HistoricalClient,
    };

    /// Builds a historical client that talks to `mock_server` and uses the
    /// given `upgrade_policy`.
    fn client(mock_server: &MockServer, upgrade_policy: VersionUpgradePolicy) -> HistoricalClient {
        HistoricalClient::builder()
            .base_url(mock_server.uri().parse().unwrap())
            .key(API_KEY)
            .unwrap()
            .upgrade_policy(upgrade_policy)
            .build()
            .unwrap()
    }

    /// Streams the trades fixture through `get_range` and checks the decoded
    /// metadata version and record count for each upgrade policy.
    #[rstest]
    #[case(VersionUpgradePolicy::AsIs, 1)]
    #[case(VersionUpgradePolicy::UpgradeToV2, 2)]
    #[case(VersionUpgradePolicy::UpgradeToV3, 3)]
    #[tokio::test]
    async fn test_get_range(#[case] upgrade_policy: VersionUpgradePolicy, #[case] exp_version: u8) {
        const START: time::OffsetDateTime = datetime!(2023 - 06 - 14 00:00 UTC);
        const END: time::OffsetDateTime = datetime!(2023 - 06 - 17 00:00 UTC);
        const SCHEMA: Schema = Schema::Trades;

        let mock_server = MockServer::start().await;
        let bytes = tokio::fs::read(zst_test_data_path(SCHEMA)).await.unwrap();
        let start_nanos = START.unix_timestamp_nanos().to_string();
        let end_nanos = END.unix_timestamp_nanos().to_string();
        let response = ResponseTemplate::new(StatusCode::OK.as_u16()).set_body_bytes(bytes);
        Mock::given(method("POST"))
            .and(basic_auth(API_KEY, ""))
            .and(path(format!("/v{API_VERSION}/timeseries.get_range")))
            .and(body_contains("dataset", "XNAS.ITCH"))
            .and(body_contains("schema", "trades"))
            .and(body_contains("symbols", "SPOT%2CAAPL"))
            .and(body_contains("start", start_nanos))
            .and(body_contains("end", end_nanos))
            .and(body_contains("stype_in", "raw_symbol"))
            .and(body_contains("stype_out", "instrument_id"))
            .respond_with(response)
            .mount(&mock_server)
            .await;

        let request = GetRangeParams::builder()
            .dataset(Dataset::XnasItch)
            .schema(SCHEMA)
            .symbols(vec!["SPOT", "AAPL"])
            .date_time_range(START..END)
            .build();
        let mut target = client(&mock_server, upgrade_policy);
        let mut decoder = target.timeseries().get_range(&request).await.unwrap();

        let metadata = decoder.metadata();
        assert_eq!(metadata.schema.unwrap(), SCHEMA);
        assert_eq!(metadata.version, exp_version);
        // The fixture holds exactly two trade records.
        for _ in 0..2 {
            assert!(decoder.decode_record::<TradeMsg>().await.unwrap().is_some());
        }
        assert!(decoder.decode_record::<TradeMsg>().await.unwrap().is_none());
    }

    /// Streams the trades fixture through `get_range_to_file` and checks that
    /// the file-backed decoder reports the expected version and record count.
    #[rstest]
    #[case(VersionUpgradePolicy::AsIs, 1)]
    #[case(VersionUpgradePolicy::UpgradeToV2, 2)]
    #[case(VersionUpgradePolicy::UpgradeToV3, 3)]
    #[tokio::test]
    async fn test_get_range_to_file(
        #[case] upgrade_policy: VersionUpgradePolicy,
        #[case] exp_version: u8,
    ) {
        const START: time::OffsetDateTime = datetime!(2024 - 05 - 17 00:00 UTC);
        const END: time::OffsetDateTime = datetime!(2024 - 05 - 18 00:00 UTC);
        const SCHEMA: Schema = Schema::Trades;
        const DATASET: &str = Dataset::IfeuImpact.as_str();

        let mock_server = MockServer::start().await;
        // Keep the directory guard alive until the end of the test so the
        // output file is not removed while the decoder still reads from it.
        let temp_dir = tempfile::TempDir::new().unwrap();
        let bytes = tokio::fs::read(zst_test_data_path(SCHEMA)).await.unwrap();
        let start_nanos = START.unix_timestamp_nanos().to_string();
        let end_nanos = END.unix_timestamp_nanos().to_string();
        let response = ResponseTemplate::new(StatusCode::OK.as_u16()).set_body_bytes(bytes);
        Mock::given(method("POST"))
            .and(basic_auth(API_KEY, ""))
            .and(path(format!("/v{API_VERSION}/timeseries.get_range")))
            .and(body_contains("dataset", DATASET))
            .and(body_contains("schema", "trades"))
            .and(body_contains("symbols", "BRN.FUT"))
            .and(body_contains("start", start_nanos))
            .and(body_contains("end", end_nanos))
            .and(body_contains("stype_in", "parent"))
            .and(body_contains("stype_out", "instrument_id"))
            .respond_with(response)
            .mount(&mock_server)
            .await;

        // Named `out_path` so it does not shadow the wiremock `path` matcher.
        let out_path = temp_dir.path().join("test.dbn.zst");
        let request = GetRangeToFileParams::builder()
            .dataset(DATASET)
            .schema(SCHEMA)
            .symbols(vec!["BRN.FUT"])
            .stype_in(SType::Parent)
            .date_time_range(START..END)
            .path(out_path)
            .build();
        let mut target = client(&mock_server, upgrade_policy);
        let mut decoder = target
            .timeseries()
            .get_range_to_file(&request)
            .await
            .unwrap();

        let metadata = decoder.metadata();
        assert_eq!(metadata.schema.unwrap(), SCHEMA);
        assert_eq!(metadata.version, exp_version);
        // The fixture holds exactly two trade records.
        for _ in 0..2 {
            assert!(decoder.decode_record::<TradeMsg>().await.unwrap().is_some());
        }
        assert!(decoder.decode_record::<TradeMsg>().await.unwrap().is_none());
    }
}