use core::fmt;
use std::{
collections::HashMap,
fmt::Write,
num::NonZeroU64,
path::{Path, PathBuf},
str::FromStr,
};
use dbn::{Compression, Encoding, SType, Schema};
use futures::StreamExt;
use log::info;
use reqwest::RequestBuilder;
use serde::{de, Deserialize, Deserializer};
use time::OffsetDateTime;
use tokio::io::BufWriter;
use typed_builder::TypedBuilder;
use crate::{Error, Symbols};
use super::DateTimeRange;
/// A client for the batch group of Historical API endpoints.
pub struct BatchClient<'a> {
    // Mutably borrows the parent client so requests share its configuration
    // (base URL, auth, connection pool).
    pub(crate) inner: &'a mut super::Client,
}
impl BatchClient<'_> {
    /// The path prefix shared by all batch API endpoints.
    const PATH_PREFIX: &str = "batch";

    /// Submits a new batch data job described by `params` and returns the
    /// server's description of the created job.
    ///
    /// # Errors
    /// This function returns an error when the request fails to send or the
    /// server responds with a non-success status code.
    pub async fn submit_job(&mut self, params: &SubmitJobParams) -> crate::Result<BatchJob> {
        let mut form = vec![
            ("dataset", params.dataset.to_string()),
            ("schema", params.schema.to_string()),
            ("encoding", params.encoding.to_string()),
            ("compression", params.compression.to_string()),
            ("pretty_px", params.pretty_px.to_string()),
            ("pretty_ts", params.pretty_ts.to_string()),
            ("map_symbols", params.map_symbols.to_string()),
            ("split_symbols", params.split_symbols.to_string()),
            ("split_duration", params.split_duration.to_string()),
            (
                "packaging",
                // `None` is encoded as the literal string "none".
                params
                    .packaging
                    .map_or_else(|| "none".to_owned(), |p| p.to_string()),
            ),
            ("delivery", params.delivery.to_string()),
            ("stype_in", params.stype_in.to_string()),
            ("stype_out", params.stype_out.to_string()),
            ("symbols", params.symbols.to_api_string()),
        ];
        params.date_time_range.add_to_form(&mut form);
        // Optional fields are omitted from the form entirely when unset.
        if let Some(split_size) = params.split_size {
            form.push(("split_size", split_size.to_string()));
        }
        if let Some(limit) = params.limit {
            form.push(("limit", limit.to_string()));
        }
        let builder = self.post("submit_job")?.form(&form);
        Ok(builder.send().await?.error_for_status()?.json().await?)
    }

    /// Lists previously submitted batch jobs, optionally filtered by `params`.
    ///
    /// # Errors
    /// This function returns an error when the request fails to send or the
    /// server responds with a non-success status code.
    pub async fn list_jobs(&mut self, params: &ListJobsParams) -> crate::Result<Vec<BatchJob>> {
        let mut builder = self.get("list_jobs")?;
        if let Some(ref states) = params.states {
            // The API expects a comma-separated list, e.g. "received,queued".
            let states_str = states
                .iter()
                .map(JobState::as_str)
                .collect::<Vec<_>>()
                .join(",");
            builder = builder.query(&[("states", states_str)]);
        }
        if let Some(ref since) = params.since {
            builder = builder.query(&[("since", &since.unix_timestamp_nanos().to_string())]);
        }
        Ok(builder.send().await?.error_for_status()?.json().await?)
    }

    /// Lists the file descriptions (name, size, hash, URLs) for the batch job
    /// with ID `job_id`.
    ///
    /// # Errors
    /// This function returns an error when the request fails to send or the
    /// server responds with a non-success status code.
    pub async fn list_files(&mut self, job_id: &str) -> crate::Result<Vec<BatchFileDesc>> {
        Ok(self
            .get("list_files")?
            .query(&[("job_id", job_id)])
            .send()
            .await?
            .error_for_status()?
            .json()
            .await?)
    }

    /// Downloads the batch job files to `<output_dir>/<job_id>/` and returns
    /// the local paths written. When `params.filename_to_download` is set,
    /// only that file is downloaded; otherwise every file of the job is.
    ///
    /// # Errors
    /// This function returns an error when the job directory can't be created,
    /// the requested filename isn't part of the job, a file description lacks
    /// an HTTPS URL, or any download request fails.
    pub async fn download(&mut self, params: &DownloadParams) -> crate::Result<Vec<PathBuf>> {
        let job_dir = params.output_dir.join(&params.job_id);
        if job_dir.exists() {
            if !job_dir.is_dir() {
                return Err(Error::bad_arg(
                    "output_dir",
                    "exists but is not a directory",
                ));
            }
        } else {
            tokio::fs::create_dir_all(&job_dir).await?;
        }
        let job_files = self.list_files(&params.job_id).await?;
        if let Some(filename_to_download) = params.filename_to_download.as_ref() {
            let Some(file_desc) = job_files
                .iter()
                .find(|file| file.filename == *filename_to_download)
            else {
                return Err(Error::bad_arg(
                    "filename_to_download",
                    "not found for batch job",
                ));
            };
            let output_path = job_dir.join(filename_to_download);
            self.download_file(Self::https_url(file_desc)?, &output_path)
                .await?;
            Ok(vec![output_path])
        } else {
            let mut paths = Vec::with_capacity(job_files.len());
            for file_desc in job_files.iter() {
                // Reuse `job_dir` rather than re-joining `output_dir` and
                // `job_id` for every file.
                let output_path = job_dir.join(&file_desc.filename);
                self.download_file(Self::https_url(file_desc)?, &output_path)
                    .await?;
                paths.push(output_path);
            }
            Ok(paths)
        }
    }

    /// Returns the HTTPS download URL for `file_desc`, or an internal error
    /// when the server response didn't include one.
    fn https_url(file_desc: &BatchFileDesc) -> crate::Result<&str> {
        file_desc
            .urls
            .get("https")
            .map(String::as_str)
            .ok_or_else(|| Error::internal("Missing https URL for batch file"))
    }

    /// Streams the contents of `url` to `path`, truncating any existing file.
    async fn download_file(&mut self, url: &str, path: impl AsRef<Path>) -> crate::Result<()> {
        let url = reqwest::Url::parse(url)
            .map_err(|e| Error::internal(format!("Unable to parse URL: {e:?}")))?;
        // Re-request through the authenticated client using only the path so
        // the parent client's base URL and credentials apply.
        let mut stream = self
            .inner
            .get_with_path(url.path())?
            .send()
            .await?
            .error_for_status()?
            .bytes_stream();
        info!("Saving {url} to {}", path.as_ref().display());
        let mut output = BufWriter::new(
            tokio::fs::OpenOptions::new()
                .create(true)
                .truncate(true)
                .write(true)
                .open(path)
                .await?,
        );
        // Copy chunk by chunk so the whole file is never buffered in memory.
        while let Some(chunk) = stream.next().await {
            tokio::io::copy(&mut chunk?.as_ref(), &mut output).await?;
        }
        Ok(())
    }

    /// Builds a GET request for the batch endpoint `slug`.
    fn get(&mut self, slug: &str) -> crate::Result<RequestBuilder> {
        self.inner.get(&format!("{}.{slug}", Self::PATH_PREFIX))
    }

    /// Builds a POST request for the batch endpoint `slug`.
    fn post(&mut self, slug: &str) -> crate::Result<RequestBuilder> {
        self.inner.post(&format!("{}.{slug}", Self::PATH_PREFIX))
    }
}
/// How large batch outputs are split into multiple files by time interval.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum SplitDuration {
    /// One file per day (the default).
    #[default]
    Day,
    /// One file per week.
    Week,
    /// One file per month.
    Month,
}
/// The archive format the batch job's files are packaged in. When no
/// packaging is requested, `Option<Packaging>::None` is used instead of a
/// variant here.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Packaging {
    /// A ZIP archive.
    Zip,
    /// A TAR archive.
    Tar,
}
/// The delivery mechanism for a batch job's files.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum Delivery {
    /// Direct download over HTTPS (the default).
    #[default]
    Download,
    /// Delivery to Amazon S3.
    S3,
    /// Delivery to disk.
    Disk,
}
/// The lifecycle state of a batch job.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum JobState {
    /// The job has been received but not yet queued.
    Received,
    /// The job is queued for processing.
    Queued,
    /// The job is currently being processed.
    Processing,
    /// The job finished processing and its files are available.
    Done,
    /// The job's files are no longer available.
    Expired,
}
/// The parameters for [`BatchClient::submit_job()`]. Use
/// [`SubmitJobParams::builder()`] to construct an instance.
#[derive(Debug, Clone, TypedBuilder)]
pub struct SubmitJobParams {
    /// The dataset code, e.g. "XNAS.ITCH".
    #[builder(setter(transform = |dt: impl ToString| dt.to_string()))]
    pub dataset: String,
    /// The symbols to request data for.
    #[builder(setter(into))]
    pub symbols: Symbols,
    /// The data record schema.
    pub schema: Schema,
    /// The request time range.
    #[builder(setter(into))]
    pub date_time_range: DateTimeRange,
    /// The data encoding. Defaults to DBN.
    #[builder(default = Encoding::Dbn)]
    pub encoding: Encoding,
    /// The compression of the output files. Defaults to Zstandard.
    #[builder(default = Compression::ZStd)]
    pub compression: Compression,
    /// Whether to format prices in a human-readable way. Defaults to `false`.
    #[builder(default)]
    pub pretty_px: bool,
    /// Whether to format timestamps in a human-readable way. Defaults to
    /// `false`.
    #[builder(default)]
    pub pretty_ts: bool,
    /// Whether to include a symbol mapping. Defaults to `false`.
    #[builder(default)]
    pub map_symbols: bool,
    /// Whether to split the output by symbol. Defaults to `false`.
    #[builder(default)]
    pub split_symbols: bool,
    /// The time interval files are split by. Defaults to [`SplitDuration::Day`].
    #[builder(default)]
    pub split_duration: SplitDuration,
    /// The maximum size of each split file, if any.
    #[builder(default, setter(strip_option))]
    pub split_size: Option<NonZeroU64>,
    /// The packaging of the output files, if any. Sent as "none" when unset.
    #[builder(default, setter(strip_option))]
    pub packaging: Option<Packaging>,
    /// The delivery mechanism. Defaults to [`Delivery::Download`].
    #[builder(default)]
    pub delivery: Delivery,
    /// The input symbology type. Defaults to raw symbols.
    #[builder(default = SType::RawSymbol)]
    pub stype_in: SType,
    /// The output symbology type. Defaults to instrument IDs.
    #[builder(default = SType::InstrumentId)]
    pub stype_out: SType,
    /// The optional maximum number of records to return.
    #[builder(default)]
    pub limit: Option<NonZeroU64>,
}
/// The description of a submitted batch job as returned by the API.
#[derive(Debug, Clone, Deserialize)]
pub struct BatchJob {
    /// The unique job ID.
    pub id: String,
    /// The ID of the user who submitted the job, if available.
    pub user_id: Option<String>,
    /// The billing ID associated with the job, if available.
    pub bill_id: Option<String>,
    /// The cost of the job in US dollars, if available.
    pub cost_usd: Option<f64>,
    /// The dataset code.
    pub dataset: String,
    /// The requested symbols.
    pub symbols: Symbols,
    /// The input symbology type.
    pub stype_in: SType,
    /// The output symbology type.
    pub stype_out: SType,
    /// The data record schema.
    pub schema: Schema,
    /// The start of the request time range. Accepts both the current and the
    /// legacy server date-time formats.
    #[serde(deserialize_with = "deserialize_date_time")]
    pub start: OffsetDateTime,
    /// The end of the request time range.
    #[serde(deserialize_with = "deserialize_date_time")]
    pub end: OffsetDateTime,
    /// The maximum number of records, if any.
    pub limit: Option<NonZeroU64>,
    /// The data encoding.
    pub encoding: Encoding,
    /// The compression of the output files. `null` is mapped to
    /// [`Compression::None`].
    #[serde(deserialize_with = "deserialize_compression")]
    pub compression: Compression,
    /// Whether prices are formatted in a human-readable way.
    pub pretty_px: bool,
    /// Whether timestamps are formatted in a human-readable way.
    pub pretty_ts: bool,
    /// Whether a symbol mapping is included.
    pub map_symbols: bool,
    /// Whether the output is split by symbol.
    pub split_symbols: bool,
    /// The time interval files are split by.
    pub split_duration: SplitDuration,
    /// The maximum size of each split file, if any.
    pub split_size: Option<NonZeroU64>,
    /// The packaging of the output files, if any.
    pub packaging: Option<Packaging>,
    /// The delivery mechanism.
    pub delivery: Delivery,
    /// The number of records in the result, if known.
    pub record_count: Option<u64>,
    /// The billed size in bytes, if known.
    pub billed_size: Option<u64>,
    /// The actual size in bytes, if known.
    pub actual_size: Option<u64>,
    /// The packaged size in bytes, if known.
    pub package_size: Option<u64>,
    /// The current state of the job.
    pub state: JobState,
    /// When the job was received.
    #[serde(deserialize_with = "deserialize_date_time")]
    pub ts_received: OffsetDateTime,
    /// When the job was queued, if it has been.
    #[serde(deserialize_with = "deserialize_opt_date_time")]
    pub ts_queued: Option<OffsetDateTime>,
    /// When processing started, if it has.
    #[serde(deserialize_with = "deserialize_opt_date_time")]
    pub ts_process_start: Option<OffsetDateTime>,
    /// When processing finished, if it has.
    #[serde(deserialize_with = "deserialize_opt_date_time")]
    pub ts_process_done: Option<OffsetDateTime>,
    /// When the job's files expire, if known.
    #[serde(deserialize_with = "deserialize_opt_date_time")]
    pub ts_expiration: Option<OffsetDateTime>,
}
/// The parameters for [`BatchClient::list_jobs()`]. Use
/// [`ListJobsParams::builder()`] to construct an instance.
#[derive(Debug, Clone, Default, TypedBuilder)]
pub struct ListJobsParams {
    /// When set, only return jobs in one of these states.
    #[builder(default, setter(strip_option))]
    pub states: Option<Vec<JobState>>,
    /// When set, only return jobs received at or after this time (sent as
    /// nanoseconds since the UNIX epoch).
    #[builder(default, setter(strip_option))]
    pub since: Option<OffsetDateTime>,
}
/// The description of one file belonging to a batch job.
#[derive(Debug, Clone, Deserialize)]
pub struct BatchFileDesc {
    /// The file's name.
    pub filename: String,
    /// The file's size in bytes.
    pub size: u64,
    /// The file's hash.
    pub hash: String,
    /// Download URLs keyed by protocol; `download()` uses the "https" entry.
    pub urls: HashMap<String, String>,
}
/// The parameters for [`BatchClient::download()`]. Use
/// [`DownloadParams::builder()`] to construct an instance.
#[derive(Debug, Clone, TypedBuilder)]
pub struct DownloadParams {
    /// The directory files are saved under; a `job_id` subdirectory is
    /// created inside it.
    pub output_dir: PathBuf,
    /// The ID of the batch job to download files from.
    pub job_id: String,
    /// When set, download only this file; otherwise all of the job's files.
    #[builder(default, setter(strip_option))]
    pub filename_to_download: Option<String>,
}
// The older server date-time format with a space separator and explicit
// offset, e.g. "2023-06-17 00:00:00.000000+00:00". The subsecond part is
// optional.
const LEGACY_DATE_TIME_FORMAT: &[time::format_description::FormatItem<'static>] =
    time::macros::format_description!("[year]-[month]-[day] [hour]:[minute]:[second][optional [.[subsecond digits:6]]][offset_hour]:[offset_minute]");
// The current server date-time format: RFC 3339-style with nanosecond
// precision and a literal "Z" suffix, e.g. "2023-06-14T00:00:00.000000000Z".
const DATE_TIME_FORMAT: &[time::format_description::FormatItem<'static>] = time::macros::format_description!(
    "[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:9]Z"
);
fn deserialize_date_time<'de, D: serde::Deserializer<'de>>(
deserializer: D,
) -> Result<time::OffsetDateTime, D::Error> {
let dt_str = String::deserialize(deserializer)?;
time::PrimitiveDateTime::parse(&dt_str, DATE_TIME_FORMAT)
.map(|dt| dt.assume_utc())
.or_else(|_| time::OffsetDateTime::parse(&dt_str, LEGACY_DATE_TIME_FORMAT))
.map_err(serde::de::Error::custom)
}
/// Deserializes an optional date-time string, accepting the current format
/// first and falling back to the legacy format. `null` maps to `None`.
fn deserialize_opt_date_time<'de, D: serde::Deserializer<'de>>(
    deserializer: D,
) -> Result<Option<time::OffsetDateTime>, D::Error> {
    Option::<String>::deserialize(deserializer)?
        .map(|raw| {
            time::PrimitiveDateTime::parse(&raw, DATE_TIME_FORMAT)
                // The current format carries no offset; assume UTC.
                .map(|dt| dt.assume_utc())
                .or_else(|_| time::OffsetDateTime::parse(&raw, LEGACY_DATE_TIME_FORMAT))
                .map_err(serde::de::Error::custom)
        })
        .transpose()
}
impl SplitDuration {
pub const fn as_str(&self) -> &'static str {
match self {
SplitDuration::Day => "day",
SplitDuration::Week => "week",
SplitDuration::Month => "month",
}
}
}
impl fmt::Display for SplitDuration {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(self.as_str())
}
}
impl FromStr for SplitDuration {
type Err = crate::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"day" => Ok(SplitDuration::Day),
"week" => Ok(SplitDuration::Week),
"month" => Ok(SplitDuration::Month),
_ => Err(crate::Error::bad_arg(
"s",
format!(
"{s} does not correspond with any {} variant",
std::any::type_name::<Self>()
),
)),
}
}
}
impl<'de> Deserialize<'de> for SplitDuration {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let str = String::deserialize(deserializer)?;
FromStr::from_str(&str).map_err(de::Error::custom)
}
}
impl Packaging {
pub const fn as_str(&self) -> &'static str {
match self {
Packaging::Zip => "zip",
Packaging::Tar => "tar",
}
}
}
impl fmt::Display for Packaging {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(self.as_str())
}
}
impl FromStr for Packaging {
type Err = crate::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"zip" => Ok(Packaging::Zip),
"tar" => Ok(Packaging::Tar),
_ => Err(crate::Error::bad_arg(
"s",
format!(
"{s} does not correspond with any {} variant",
std::any::type_name::<Self>()
),
)),
}
}
}
impl<'de> Deserialize<'de> for Packaging {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let str = String::deserialize(deserializer)?;
FromStr::from_str(&str).map_err(de::Error::custom)
}
}
impl Delivery {
pub const fn as_str(&self) -> &'static str {
match self {
Delivery::Download => "download",
Delivery::S3 => "s3",
Delivery::Disk => "disk",
}
}
}
impl fmt::Display for Delivery {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(self.as_str())
}
}
impl FromStr for Delivery {
type Err = crate::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"download" => Ok(Delivery::Download),
"s3" => Ok(Delivery::S3),
"disk" => Ok(Delivery::Disk),
_ => Err(crate::Error::bad_arg(
"s",
format!(
"{s} does not correspond with any {} variant",
std::any::type_name::<Self>()
),
)),
}
}
}
impl<'de> Deserialize<'de> for Delivery {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let str = String::deserialize(deserializer)?;
FromStr::from_str(&str).map_err(de::Error::custom)
}
}
impl JobState {
pub const fn as_str(&self) -> &'static str {
match self {
JobState::Received => "received",
JobState::Queued => "queued",
JobState::Processing => "processing",
JobState::Done => "done",
JobState::Expired => "expired",
}
}
}
impl fmt::Display for JobState {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(self.as_str())
}
}
impl FromStr for JobState {
type Err = crate::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"received" => Ok(JobState::Received),
"queued" => Ok(JobState::Queued),
"processing" => Ok(JobState::Processing),
"done" => Ok(JobState::Done),
"expired" => Ok(JobState::Expired),
_ => Err(crate::Error::bad_arg(
"s",
format!(
"{s} does not correspond with any {} variant",
std::any::type_name::<Self>()
),
)),
}
}
}
impl<'de> Deserialize<'de> for JobState {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let str = String::deserialize(deserializer)?;
FromStr::from_str(&str).map_err(de::Error::custom)
}
}
/// Deserializes a compression value, mapping a missing or `null` value to
/// [`Compression::None`].
fn deserialize_compression<'de, D: serde::Deserializer<'de>>(
    deserializer: D,
) -> Result<Compression, D::Error> {
    Ok(Option::<Compression>::deserialize(deserializer)?.unwrap_or(Compression::None))
}
#[cfg(test)]
mod tests {
    use reqwest::StatusCode;
    use serde_json::json;
    use time::macros::datetime;
    use wiremock::{
        matchers::{basic_auth, method, path, query_param_is_missing},
        Mock, MockServer, ResponseTemplate,
    };

    use super::*;
    use crate::{
        body_contains,
        historical::{HistoricalGateway, API_VERSION},
        HistoricalClient,
    };

    // The API key used to authenticate against the mock server.
    const API_KEY: &str = "test-batch";

    // Verifies that `submit_job` sends the expected form fields and parses
    // the returned job description. The response's `start` uses the current
    // date-time format and `end` the legacy one, exercising both parse paths.
    #[tokio::test]
    async fn test_submit_job() -> crate::Result<()> {
        const START: time::OffsetDateTime = datetime!(2023 - 06 - 14 00:00 UTC);
        const END: time::OffsetDateTime = datetime!(2023 - 06 - 17 00:00 UTC);
        const SCHEMA: Schema = Schema::Trades;

        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(basic_auth(API_KEY, ""))
            .and(path(format!("/v{API_VERSION}/batch.submit_job")))
            .and(body_contains("dataset", "XNAS.ITCH"))
            .and(body_contains("schema", "trades"))
            .and(body_contains("symbols", "TSLA"))
            .and(body_contains(
                "start",
                START.unix_timestamp_nanos().to_string(),
            ))
            .and(body_contains("encoding", "dbn"))
            .and(body_contains("compression", "zstd"))
            .and(body_contains("map_symbols", "false"))
            .and(body_contains("end", END.unix_timestamp_nanos().to_string()))
            .and(body_contains("stype_in", "raw_symbol"))
            .and(body_contains("stype_out", "instrument_id"))
            .respond_with(ResponseTemplate::new(StatusCode::OK).set_body_json(json!({
                "id": "123",
                "user_id": "test_user",
                "bill_id": "345",
                "cost_usd": 10.50,
                "dataset": "XNAS.ITCH",
                "symbols": ["TSLA"],
                "stype_in": "raw_symbol",
                "stype_out": "instrument_id",
                "schema": SCHEMA.as_str(),
                "start": "2023-06-14T00:00:00.000000000Z",
                "end": "2023-06-17 00:00:00.000000+00:00",
                "limit": null,
                "encoding": "dbn",
                "compression": "zstd",
                "pretty_px": false,
                "pretty_ts": false,
                "map_symbols": false,
                "split_symbols": false,
                "split_duration": "day",
                "split_size": null,
                "packaging": null,
                "delivery": "download",
                "state": "queued",
                "ts_received": "2023-07-19T23:00:04.095538123Z",
                "ts_queued": null,
                "ts_process_start": null,
                "ts_process_done": null,
                "ts_expiration": null
            })))
            .mount(&mock_server)
            .await;
        let mut target = HistoricalClient::with_url(
            mock_server.uri(),
            API_KEY.to_owned(),
            HistoricalGateway::Bo1,
        )?;
        let job_desc = target
            .batch()
            .submit_job(
                &SubmitJobParams::builder()
                    .dataset(dbn::datasets::XNAS_ITCH)
                    .schema(SCHEMA)
                    .symbols("TSLA")
                    .date_time_range((START, END))
                    .build(),
            )
            .await?;
        assert_eq!(job_desc.dataset, dbn::datasets::XNAS_ITCH);
        Ok(())
    }

    // Verifies that `list_jobs` with default params omits the optional query
    // parameters and correctly parses mixed-format timestamps and flags from
    // the response.
    #[tokio::test]
    async fn test_list_jobs() -> crate::Result<()> {
        const SCHEMA: Schema = Schema::Trades;

        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(basic_auth(API_KEY, ""))
            .and(path(format!("/v{API_VERSION}/batch.list_jobs")))
            // No filters set, so neither query parameter should be sent.
            .and(query_param_is_missing("states"))
            .and(query_param_is_missing("since"))
            .respond_with(ResponseTemplate::new(StatusCode::OK).set_body_json(json!([{
                "id": "123",
                "user_id": "test_user",
                "bill_id": "345",
                "cost_usd": 10.50,
                "dataset": "XNAS.ITCH",
                "symbols": "TSLA",
                "stype_in": "raw_symbol",
                "stype_out": "instrument_id",
                "schema": SCHEMA.as_str(),
                "start": "2023-06-14 00:00:00+00:00",
                "end": "2023-06-17T00:00:00.012345678Z",
                "limit": null,
                "encoding": "json",
                "compression": "zstd",
                "pretty_px": true,
                "pretty_ts": false,
                "map_symbols": true,
                "split_symbols": false,
                "split_duration": "day",
                "split_size": null,
                "packaging": null,
                "delivery": "download",
                "state": "processing",
                "ts_received": "2023-07-19 23:00:04.095538+00:00",
                "ts_queued": "2023-07-19T23:00:08.095538123Z",
                "ts_process_start": "2023-07-19 23:01:04.000000+00:00",
                "ts_process_done": null,
                "ts_expiration": null
            }])))
            .mount(&mock_server)
            .await;
        let mut target = HistoricalClient::with_url(
            mock_server.uri(),
            API_KEY.to_owned(),
            HistoricalGateway::Bo1,
        )?;
        let job_descs = target.batch().list_jobs(&ListJobsParams::default()).await?;
        assert_eq!(job_descs.len(), 1);
        let job_desc = &job_descs[0];
        assert_eq!(
            job_desc.ts_queued.unwrap(),
            datetime!(2023-07-19 23:00:08.095538123 UTC)
        );
        assert_eq!(
            job_desc.ts_process_start.unwrap(),
            datetime!(2023-07-19 23:01:04 UTC)
        );
        assert_eq!(job_desc.encoding, Encoding::Json);
        assert!(job_desc.pretty_px);
        assert!(!job_desc.pretty_ts);
        assert!(job_desc.map_symbols);
        Ok(())
    }

    // Verifies that `deserialize_compression` maps JSON `null` to
    // `Compression::None` and parses explicit values as usual.
    #[test]
    fn test_deserialize_compression() {
        #[derive(serde::Deserialize)]
        struct Test {
            #[serde(deserialize_with = "deserialize_compression")]
            compression: Compression,
        }
        const JSON: &str =
            r#"[{"compression":null}, {"compression":"none"}, {"compression":"zstd"}]"#;
        let res: Vec<Test> = serde_json::from_str(JSON).unwrap();
        assert_eq!(
            res.into_iter().map(|t| t.compression).collect::<Vec<_>>(),
            vec![Compression::None, Compression::None, Compression::ZStd]
        );
    }
}