use std::{sync::Arc, vec::IntoIter};

use ahash::{AHashMap, AHashSet};
use datafusion::{
    error::Result, logical_expr::expr::Sort, physical_plan::SendableRecordBatchStream, prelude::*,
};
use futures::StreamExt;
use nautilus_core::{UnixNanos, ffi::cvec::CVec};
use nautilus_model::data::{Data, HasTsInit};
use nautilus_serialization::arrow::{
    DataStreamingError, DecodeDataFromRecordBatch, EncodeToRecordBatch, WriteStream,
};
use object_store::ObjectStore;
use url::Url;

use super::{
    compare::Compare,
    kmerge_batch::{EagerStream, ElementBatchIter, KMerge},
};
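
/// Orders [`ElementBatchIter`] elements by the `ts_init` of their current item.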
#[derive(Debug, Default)]
pub struct TsInitComparator;

impl<I> Compare<ElementBatchIter<I, Data>> for TsInitComparator
where
I: Iterator<Item = IntoIter<Data>>,
{
fn compare(
&self,
l: &ElementBatchIter<I, Data>,
r: &ElementBatchIter<I, Data>,
) -> std::cmp::Ordering {
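        // Reverse the comparison so the k-merge yields elements in ascending `ts_init` order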
l.item.ts_init().cmp(&r.item.ts_init()).reverse()
}
}
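
/// A k-way merged iterator over [`Data`] from all registered batch streams,
/// ordered by `ts_init`.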
pub type QueryResult = KMerge<EagerStream<std::vec::IntoIter<Data>>, Data, TsInitComparator>;
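
/// A DataFusion-backed session for querying Parquet files of Nautilus [`Data`]
/// and streaming the results as a single sequence ordered by `ts_init`.
///
/// Each file registered via [`DataBackendSession::add_file`] becomes its own
/// `ts_init`-sorted batch stream; [`DataBackendSession::get_query_result`]
/// k-merges those streams into one ordered iterator.
///
/// A minimal usage sketch (the file path and the `QuoteTick` element type are
/// illustrative assumptions, not mandated by this API):
///
/// ```ignore
/// use nautilus_model::data::QuoteTick;
///
/// let mut session = DataBackendSession::new(1000);
/// session.add_file::<QuoteTick>("quotes", "/tmp/quotes.parquet", None, None)?;
/// for data in session.get_query_result() {
///     // Items arrive in ascending `ts_init` order across all files
/// }
/// ```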
#[cfg_attr(
feature = "python",
pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence", unsendable)
)]
#[cfg_attr(
feature = "python",
pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.persistence")
)]
pub struct DataBackendSession {
pub chunk_size: usize,
pub runtime: Arc<tokio::runtime::Runtime>,
session_ctx: SessionContext,
batch_streams: Vec<EagerStream<IntoIter<Data>>>,
registered_tables: AHashSet<String>,
}

impl DataBackendSession {
    /// Creates a new session. The `chunk_size` is the intended number of
    /// [`Data`] items per chunk when results are drained (see [`DataQueryResult`]).
    #[must_use]
    pub fn new(chunk_size: usize) -> Self {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .expect("Failed to build Tokio runtime");
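
        // Preserve per-file sort order: do not repartition file scans, and
        // prefer plans that reuse the existing `ts_init` ordering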
let session_cfg = SessionConfig::new()
.set_str("datafusion.optimizer.repartition_file_scans", "false")
.set_str("datafusion.optimizer.prefer_existing_sort", "true");
let session_ctx = SessionContext::new_with_config(session_cfg);
Self {
session_ctx,
batch_streams: Vec::default(),
chunk_size,
runtime: Arc::new(runtime),
registered_tables: AHashSet::new(),
}
}
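
    /// Registers an [`ObjectStore`] with the session context under the given base `url`.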
pub fn register_object_store(&mut self, url: &Url, object_store: Arc<dyn ObjectStore>) {
self.session_ctx.register_object_store(url, object_store);
}
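
    /// Creates an object store for `uri` (with optional `storage_options`) and,
    /// for remote schemes such as `s3`, `gs`, or `https`, registers it under its
    /// `scheme://host` base URL so queries can resolve paths against it.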
pub fn register_object_store_from_uri(
&mut self,
uri: &str,
storage_options: Option<AHashMap<String, String>>,
) -> anyhow::Result<()> {
let (object_store, _, _) =
crate::parquet::create_object_store_from_path(uri, storage_options)?;
let parsed_uri = Url::parse(uri)?;
if matches!(
parsed_uri.scheme(),
"s3" | "gs" | "gcs" | "az" | "abfs" | "http" | "https"
) {
let base_url = format!(
"{}://{}",
parsed_uri.scheme(),
parsed_uri.host_str().unwrap_or("")
);
let base_parsed_url = Url::parse(&base_url)?;
self.register_object_store(&base_parsed_url, object_store);
}
Ok(())
}
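
    /// Encodes `data` into a single Arrow record batch using `metadata` and
    /// writes it to `stream`.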
pub fn write_data<T: EncodeToRecordBatch>(
data: &[T],
metadata: &AHashMap<String, String>,
stream: &mut dyn WriteStream,
) -> Result<(), DataStreamingError> {
let metadata: std::collections::HashMap<String, String> = metadata
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
let record_batch = T::encode_batch(&metadata, data)?;
stream.write(&record_batch)?;
Ok(())
}
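
    /// Registers the Parquet file at `file_path` as `table_name`, declared as
    /// pre-sorted by `ts_init`, then executes `sql_query` (defaulting to
    /// `SELECT * FROM {table_name} ORDER BY ts_init`) and queues the resulting
    /// batch stream for merging. Re-adding an already registered `table_name`
    /// is a no-op; `custom_type_name`, when given, overrides the `type_name`
    /// metadata used to decode batches into `T`.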
pub fn add_file<T>(
&mut self,
table_name: &str,
file_path: &str,
sql_query: Option<&str>,
custom_type_name: Option<&str>,
) -> Result<()>
where
T: DecodeDataFromRecordBatch,
{
let is_new_table = !self.registered_tables.contains(table_name);
if is_new_table {
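            // Keep schema metadata for decoding and declare each file as
            // pre-sorted by `ts_init` so DataFusion can maintain that order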
let parquet_options = ParquetReadOptions::<'_> {
skip_metadata: Some(false),
file_sort_order: vec![vec![Sort {
expr: col("ts_init"),
asc: true,
nulls_first: false,
}]],
..Default::default()
};
self.runtime.block_on(self.session_ctx.register_parquet(
table_name,
file_path,
parquet_options,
))?;
self.registered_tables.insert(table_name.to_string());
            let default_query = format!("SELECT * FROM {table_name} ORDER BY ts_init");
let sql_query = sql_query.unwrap_or(&default_query);
let query = self.runtime.block_on(self.session_ctx.sql(sql_query))?;
let batch_stream = self.runtime.block_on(query.execute_stream())?;
self.add_batch_stream::<T>(batch_stream, custom_type_name.map(String::from));
}
Ok(())
}
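
    /// Decodes each incoming record batch into a [`Data`] iterator (injecting
    /// `custom_type_name` as `type_name` metadata when given), runs the stream
    /// eagerly on the session runtime, and queues it for merging.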
fn add_batch_stream<T>(
&mut self,
stream: SendableRecordBatchStream,
custom_type_name: Option<String>,
) where
T: DecodeDataFromRecordBatch,
{
let transform = stream.map(move |result| match result {
Ok(batch) => {
let mut metadata: std::collections::HashMap<String, String> =
batch.schema().metadata().clone();
if let Some(ref tn) = custom_type_name {
metadata.insert("type_name".to_string(), tn.clone());
}
T::decode_data_batch(&metadata, batch).unwrap().into_iter()
}
            // TODO: Propagate decode/stream errors instead of panicking
            Err(e) => panic!("Error getting next batch from RecordBatchStream: {e}"),
});
self.batch_streams
.push(EagerStream::from_stream_with_runtime(
transform,
self.runtime.clone(),
));
}
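
    /// Drains all queued batch streams into a [`QueryResult`] that k-merges
    /// them into a single iterator ordered by `ts_init`.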
pub fn get_query_result(&mut self) -> QueryResult {
let mut kmerge: KMerge<_, _, _> = KMerge::new(TsInitComparator);
self.batch_streams
.drain(..)
.for_each(|eager_stream| kmerge.push_iter(eager_stream));
kmerge
}
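
    /// Clears all registered tables and queued streams, and resets the session
    /// context with the same configuration used by [`DataBackendSession::new`].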
pub fn clear_registered_tables(&mut self) {
self.registered_tables.clear();
self.batch_streams.clear();
let session_cfg = SessionConfig::new()
.set_str("datafusion.optimizer.repartition_file_scans", "false")
.set_str("datafusion.optimizer.prefer_existing_sort", "true");
self.session_ctx = SessionContext::new_with_config(session_cfg);
}
}
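
/// Builds a `SELECT` query for `table` with optional inclusive `ts_init`
/// bounds and an optional extra `WHERE` clause, always ordered by `ts_init`.
///
/// For example, for a `quotes` table with both bounds set this returns
/// `SELECT * FROM quotes WHERE ts_init >= 10 AND ts_init <= 20 ORDER BY ts_init`.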
#[must_use]
pub fn build_query(
table: &str,
start: Option<UnixNanos>,
end: Option<UnixNanos>,
where_clause: Option<&str>,
) -> String {
let mut conditions = Vec::new();
if let Some(clause) = where_clause {
conditions.push(clause.to_string());
}
if let Some(start_ts) = start {
conditions.push(format!("ts_init >= {start_ts}"));
}
if let Some(end_ts) = end {
conditions.push(format!("ts_init <= {end_ts}"));
}
let mut query = format!("SELECT * FROM {table}");
if !conditions.is_empty() {
query.push_str(" WHERE ");
query.push_str(&conditions.join(" AND "));
}
query.push_str(" ORDER BY ts_init");
query
}
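
/// Buffers chunks of [`Data`] drained from a [`QueryResult`] and exposes each
/// chunk as a [`CVec`] for FFI consumers.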
#[cfg_attr(
feature = "python",
pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence", unsendable)
)]
#[cfg_attr(
feature = "python",
pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.persistence")
)]
pub struct DataQueryResult {
pub chunk: Option<CVec>,
pub result: QueryResult,
pub acc: Vec<Data>,
pub size: usize,
}

impl DataQueryResult {
    /// Creates a new [`DataQueryResult`] that drains `result` in chunks of up to `size` items.
    #[must_use]
    pub const fn new(result: QueryResult, size: usize) -> Self {
Self {
chunk: None,
result,
acc: Vec::new(),
size,
}
}
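
    /// Replaces the current chunk with `data`, dropping any previous chunk,
    /// and returns a [`CVec`] handle to the new one.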
pub fn set_chunk(&mut self, data: Vec<Data>) -> CVec {
self.drop_chunk();
let chunk: CVec = data.into();
self.chunk = Some(chunk);
chunk
}
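
    /// Reclaims the current chunk (if any) as a `Vec<Data>` and drops it.
    ///
    /// The assertions guard against reclaiming a [`CVec`] that was not
    /// produced from a `Vec<Data>` by [`DataQueryResult::set_chunk`].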
pub fn drop_chunk(&mut self) {
if let Some(CVec { ptr, len, cap }) = self.chunk.take() {
assert!(
len <= cap,
"drop_chunk: len ({len}) > cap ({cap}) - memory corruption or wrong chunk type"
);
assert!(
len == 0 || !ptr.is_null(),
"drop_chunk: null ptr with non-zero len ({len}) - memory corruption"
);
let data: Vec<Data> = unsafe { Vec::from_raw_parts(ptr.cast::<Data>(), len, cap) };
drop(data);
}
}
}
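
// Yields chunks of up to `size` items; an empty chunk signals that the query
// result is exhausted, as this iterator never returns `None` itself.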
impl Iterator for DataQueryResult {
type Item = Vec<Data>;

    fn next(&mut self) -> Option<Self::Item> {
for _ in 0..self.size {
match self.result.next() {
Some(item) => self.acc.push(item),
None => break,
}
}
        // Swap out the accumulator; an empty chunk signals exhaustion to callers
        Some(std::mem::take(&mut self.acc))
}
}
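
// Releases the outstanding FFI chunk and any remaining stream state.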
impl Drop for DataQueryResult {
fn drop(&mut self) {
self.drop_chunk();
self.result.clear();
}
}