// nautilus-persistence 0.55.0
//
// Data persistence and storage for the Nautilus trading engine.
// Documentation
// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------

use std::{sync::Arc, vec::IntoIter};

use ahash::{AHashMap, AHashSet};
use datafusion::{
    error::Result, logical_expr::expr::Sort, physical_plan::SendableRecordBatchStream, prelude::*,
};
use futures::StreamExt;
use nautilus_core::{UnixNanos, ffi::cvec::CVec};
use nautilus_model::data::{Data, HasTsInit};
use nautilus_serialization::arrow::{
    DataStreamingError, DecodeDataFromRecordBatch, EncodeToRecordBatch, WriteStream,
};
use object_store::ObjectStore;
use url::Url;

use super::{
    compare::Compare,
    kmerge_batch::{EagerStream, ElementBatchIter, KMerge},
};

/// Compares [`ElementBatchIter`] elements by the `ts_init` of their current item.
///
/// The ordering is reversed so that a max-heap based k-merge pops elements
/// in ascending `ts_init` order.
#[derive(Debug, Default)]
pub struct TsInitComparator;

impl<I> Compare<ElementBatchIter<I, Data>> for TsInitComparator
where
    I: Iterator<Item = IntoIter<Data>>,
{
    /// Orders two batch iterators by their current item's `ts_init`.
    fn compare(
        &self,
        l: &ElementBatchIter<I, Data>,
        r: &ElementBatchIter<I, Data>,
    ) -> std::cmp::Ordering {
        let (lhs, rhs) = (l.item.ts_init(), r.item.ts_init());
        // Comparing right-to-left reverses the natural ordering, which is
        // required because the k-merge pops from a max-heap internally.
        rhs.cmp(&lhs)
    }
}

/// A k-way merge over eagerly-executing data streams, yielding [`Data`] in ascending `ts_init` order.
pub type QueryResult = KMerge<EagerStream<std::vec::IntoIter<Data>>, Data, TsInitComparator>;

/// Provides a DataFusion session and registers DataFusion queries.
///
/// The session is used to register data sources and make queries on them. A
/// query returns a Chunk of Arrow records. It is decoded and converted into
/// a Vec of data by types that implement [`DecodeDataFromRecordBatch`].
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence", unsendable)
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.persistence")
)]
pub struct DataBackendSession {
    /// Target number of data elements per chunk when consuming query results.
    pub chunk_size: usize,
    /// Tokio runtime used to drive DataFusion's async query execution.
    pub runtime: Arc<tokio::runtime::Runtime>,
    // DataFusion session holding registered tables and object stores.
    session_ctx: SessionContext,
    // One eagerly-executing decoded stream per registered query; drained by
    // `get_query_result` into a k-merge.
    batch_streams: Vec<EagerStream<IntoIter<Data>>>,
    // Names of tables already registered, used to avoid duplicate registration.
    registered_tables: AHashSet<String>,
}

impl DataBackendSession {
    /// Creates a new [`DataBackendSession`] instance.
    #[must_use]
    pub fn new(chunk_size: usize) -> Self {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();
        let session_cfg = SessionConfig::new()
            .set_str("datafusion.optimizer.repartition_file_scans", "false")
            .set_str("datafusion.optimizer.prefer_existing_sort", "true");
        let session_ctx = SessionContext::new_with_config(session_cfg);
        Self {
            session_ctx,
            batch_streams: Vec::default(),
            chunk_size,
            runtime: Arc::new(runtime),
            registered_tables: AHashSet::new(),
        }
    }

    /// Register an object store with the session context
    pub fn register_object_store(&mut self, url: &Url, object_store: Arc<dyn ObjectStore>) {
        self.session_ctx.register_object_store(url, object_store);
    }

    /// Register an object store with the session context from a URI with optional storage options
    pub fn register_object_store_from_uri(
        &mut self,
        uri: &str,
        storage_options: Option<AHashMap<String, String>>,
    ) -> anyhow::Result<()> {
        // Create object store from URI using the Rust implementation
        let (object_store, _, _) =
            crate::parquet::create_object_store_from_path(uri, storage_options)?;

        // Parse the URI to get the base URL for registration
        let parsed_uri = Url::parse(uri)?;

        // Register the object store with the session
        if matches!(
            parsed_uri.scheme(),
            "s3" | "gs" | "gcs" | "az" | "abfs" | "http" | "https"
        ) {
            // For cloud storage, register with the base URL (scheme + netloc)
            let base_url = format!(
                "{}://{}",
                parsed_uri.scheme(),
                parsed_uri.host_str().unwrap_or("")
            );
            let base_parsed_url = Url::parse(&base_url)?;
            self.register_object_store(&base_parsed_url, object_store);
        }

        Ok(())
    }

    pub fn write_data<T: EncodeToRecordBatch>(
        data: &[T],
        metadata: &AHashMap<String, String>,
        stream: &mut dyn WriteStream,
    ) -> Result<(), DataStreamingError> {
        // Convert AHashMap to HashMap for Arrow compatibility
        let metadata: std::collections::HashMap<String, String> = metadata
            .iter()
            .map(|(k, v)| (k.clone(), v.clone()))
            .collect();
        let record_batch = T::encode_batch(&metadata, data)?;
        stream.write(&record_batch)?;
        Ok(())
    }

    /// Registers a Parquet file and adds a batch stream for decoding.
    ///
    /// The caller must specify `T` to indicate the kind of data expected. `table_name` is
    /// the logical name for queries; `file_path` is the Parquet path; `sql_query` defaults
    /// to `SELECT * FROM {table_name} ORDER BY ts_init` if `None`.
    ///
    /// When `custom_type_name` is `Some`, it is merged into each batch's schema metadata
    /// before decoding (as `type_name`). Use this for custom data when Parquet/DataFusion
    /// does not preserve schema metadata so the decoder can look up the type in the registry.
    ///
    /// The file data must be ordered by the `ts_init` in ascending order for this
    /// to work correctly.
    pub fn add_file<T>(
        &mut self,
        table_name: &str,
        file_path: &str,
        sql_query: Option<&str>,
        custom_type_name: Option<&str>,
    ) -> Result<()>
    where
        T: DecodeDataFromRecordBatch,
    {
        // Check if table is already registered to avoid duplicates
        let is_new_table = !self.registered_tables.contains(table_name);

        if is_new_table {
            // Register the table only if it doesn't exist
            let parquet_options = ParquetReadOptions::<'_> {
                skip_metadata: Some(false),
                file_sort_order: vec![vec![Sort {
                    expr: col("ts_init"),
                    asc: true,
                    nulls_first: false,
                }]],
                ..Default::default()
            };
            self.runtime.block_on(self.session_ctx.register_parquet(
                table_name,
                file_path,
                parquet_options,
            ))?;

            self.registered_tables.insert(table_name.to_string());

            // Only add batch stream for newly registered tables to avoid duplicates
            let default_query = format!("SELECT * FROM {} ORDER BY ts_init", &table_name);
            let sql_query = sql_query.unwrap_or(&default_query);
            let query = self.runtime.block_on(self.session_ctx.sql(sql_query))?;
            let batch_stream = self.runtime.block_on(query.execute_stream())?;
            self.add_batch_stream::<T>(batch_stream, custom_type_name.map(String::from));
        }

        Ok(())
    }

    fn add_batch_stream<T>(
        &mut self,
        stream: SendableRecordBatchStream,
        custom_type_name: Option<String>,
    ) where
        T: DecodeDataFromRecordBatch,
    {
        let transform = stream.map(move |result| match result {
            Ok(batch) => {
                let mut metadata: std::collections::HashMap<String, String> =
                    batch.schema().metadata().clone();

                if let Some(ref tn) = custom_type_name {
                    metadata.insert("type_name".to_string(), tn.clone());
                }
                T::decode_data_batch(&metadata, batch).unwrap().into_iter()
            }
            Err(e) => panic!("Error getting next batch from RecordBatchStream: {e}"),
        });

        self.batch_streams
            .push(EagerStream::from_stream_with_runtime(
                transform,
                self.runtime.clone(),
            ));
    }

    // Consumes the registered queries and returns a [`QueryResult].
    // Passes the output of the query though the a KMerge which sorts the
    // queries in ascending order of `ts_init`.
    // QueryResult is an iterator that return Vec<Data>.
    pub fn get_query_result(&mut self) -> QueryResult {
        let mut kmerge: KMerge<_, _, _> = KMerge::new(TsInitComparator);

        self.batch_streams
            .drain(..)
            .for_each(|eager_stream| kmerge.push_iter(eager_stream));

        kmerge
    }

    /// Clears all registered tables and batch streams.
    ///
    /// This is useful when the underlying files have changed and we need to
    /// re-register tables with updated data.
    pub fn clear_registered_tables(&mut self) {
        self.registered_tables.clear();
        self.batch_streams.clear();

        // Create a new session context to completely reset the DataFusion state
        let session_cfg = SessionConfig::new()
            .set_str("datafusion.optimizer.repartition_file_scans", "false")
            .set_str("datafusion.optimizer.prefer_existing_sort", "true");
        self.session_ctx = SessionContext::new_with_config(session_cfg);
    }
}

/// Builds a `SELECT` query for `table`, ordered by `ts_init`.
///
/// Optional `start`/`end` bounds become inclusive `ts_init` range conditions,
/// and an optional raw `where_clause` is ANDed in first.
#[must_use]
pub fn build_query(
    table: &str,
    start: Option<UnixNanos>,
    end: Option<UnixNanos>,
    where_clause: Option<&str>,
) -> String {
    // Assemble the filter conditions in order: raw clause, lower bound, upper bound.
    let conditions: Vec<String> = where_clause
        .map(str::to_string)
        .into_iter()
        .chain(start.map(|start_ts| format!("ts_init >= {start_ts}")))
        .chain(end.map(|end_ts| format!("ts_init <= {end_ts}")))
        .collect();

    // Only emit a WHERE clause when at least one condition exists.
    let filter = if conditions.is_empty() {
        String::new()
    } else {
        format!(" WHERE {}", conditions.join(" AND "))
    };

    format!("SELECT * FROM {table}{filter} ORDER BY ts_init")
}

/// Wraps a [`QueryResult`] and exposes its merged output as fixed-size
/// chunks of [`Data`], with optional FFI-friendly `CVec` chunk storage.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence", unsendable)
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.persistence")
)]
pub struct DataQueryResult {
    /// Currently held `CVec` chunk, if any; must be dropped via `drop_chunk`.
    pub chunk: Option<CVec>,
    /// The underlying k-merged query result being consumed.
    pub result: QueryResult,
    /// Accumulator for elements pulled from `result` before a chunk is handed out.
    pub acc: Vec<Data>,
    /// Target number of elements per chunk.
    pub size: usize,
}

impl DataQueryResult {
    /// Creates a new [`DataQueryResult`] instance that yields chunks of up
    /// to `size` elements from `result`.
    #[must_use]
    pub const fn new(result: QueryResult, size: usize) -> Self {
        Self {
            chunk: None,
            result,
            acc: Vec::new(),
            size,
        }
    }

    /// Sets a new `CVec`-backed chunk from `data` and returns a copy of it.
    ///
    /// Any previously allocated chunk is dropped first, so at most one chunk
    /// is ever held at a time.
    pub fn set_chunk(&mut self, data: Vec<Data>) -> CVec {
        self.drop_chunk();

        let chunk: CVec = data.into();
        self.chunk = Some(chunk);
        chunk
    }

    /// Drops the currently held chunk, if any, and resets the field.
    ///
    /// Chunks generated by iteration must be dropped after use, otherwise
    /// the backing allocation leaks. The current chunk is held by the reader;
    /// this reconstructs the original `Vec<Data>` and drops it.
    pub fn drop_chunk(&mut self) {
        if let Some(CVec { ptr, len, cap }) = self.chunk.take() {
            // Sanity-check the CVec invariants before reconstituting the Vec.
            assert!(
                len <= cap,
                "drop_chunk: len ({len}) > cap ({cap}) - memory corruption or wrong chunk type"
            );
            assert!(
                len == 0 || !ptr.is_null(),
                "drop_chunk: null ptr with non-zero len ({len}) - memory corruption"
            );

            // SAFETY: `ptr`, `len`, and `cap` originate from a valid `CVec` and the
            // assertions above verify the invariants required by `Vec::from_raw_parts`.
            let data: Vec<Data> = unsafe { Vec::from_raw_parts(ptr.cast::<Data>(), len, cap) };
            drop(data);
        }
    }
}

impl Iterator for DataQueryResult {
    type Item = Vec<Data>;

    /// Pulls up to `self.size` elements from the merged stream and yields
    /// them as one chunk.
    ///
    /// Note: this never returns `None`; once the underlying result is
    /// exhausted it yields empty vectors.
    fn next(&mut self) -> Option<Self::Item> {
        // Fill the accumulator with at most `size` elements; stops early when
        // the underlying stream runs out.
        self.acc.extend(self.result.by_ref().take(self.size));

        // Hand the accumulated chunk to the caller, leaving a fresh empty
        // buffer in its place.
        Some(std::mem::take(&mut self.acc))
    }
}

impl Drop for DataQueryResult {
    fn drop(&mut self) {
        // Release any outstanding CVec chunk so its allocation is not leaked.
        self.drop_chunk();
        // Clear the k-merge's remaining buffered streams/elements.
        self.result.clear();
    }
}