use std::path::PathBuf;
use std::sync::Arc;
use super::{DataSource, EnrichedHistory, StorageRoot, StreamKind, TimedEvent};
use super::rest_fetcher::RestFetcher;
use crate::core::types::Bar;
/// Combines caller-supplied bars with additional event streams read from a
/// configured [`DataSource`] and returns them as an `EnrichedHistory`.
pub struct EnrichedDataLoader {
/// Source that non-bar streams are loaded from.
pub source: DataSource,
/// Fetcher used for `DataSource::Rest`; `None` until `with_rest_fetcher` is called.
rest_fetcher: Option<Arc<dyn RestFetcher>>,
}
impl EnrichedDataLoader {
    /// Creates a loader that reads non-bar streams from `source`.
    ///
    /// No REST fetcher is attached, so a `DataSource::Rest` source fails with
    /// `ErrorKind::Unsupported` until [`Self::with_rest_fetcher`] is used.
    pub fn new(source: DataSource) -> Self {
        Self { source, rest_fetcher: None }
    }

    /// Attaches the fetcher that serves `DataSource::Rest` requests and
    /// returns the loader (builder style).
    pub fn with_rest_fetcher(mut self, fetcher: Arc<dyn RestFetcher>) -> Self {
        self.rest_fetcher = Some(fetcher);
        self
    }

    /// Loads every requested stream for `symbol` and merges it with `bars`
    /// into one timestamp-ordered history.
    ///
    /// * `bars` - always included in the result; they also define the time
    ///   window: `[first.time, last.time]`, or `[0, i64::MAX]` when empty.
    ///   NOTE(review): the window assumes `bars` is ordered by time.
    /// * `streams` - kinds to load from the configured source;
    ///   `StreamKind::Bar` entries are skipped because bars come from `bars`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced while loading any requested stream.
    pub async fn load(
        &self,
        symbol: &str,
        bars: Vec<Bar>,
        streams: &[StreamKind],
    ) -> std::io::Result<EnrichedHistory> {
        let (from_ts, to_ts) = match (bars.first(), bars.last()) {
            (Some(first), Some(last)) => (first.time, last.time),
            // No bars: do not restrict the window.
            _ => (0i64, i64::MAX),
        };

        let mut events: Vec<TimedEvent> = bars.iter().cloned().map(TimedEvent::Bar).collect();
        for &kind in streams {
            if kind == StreamKind::Bar {
                continue;
            }
            let stream_events = self
                .load_stream_events(&self.source, symbol, kind, from_ts, to_ts)
                .await?;
            events.extend(stream_events);
        }

        // `sort_by_key` is stable: on equal timestamps bars stay ahead of
        // stream events, and streams keep the order they were requested in.
        events.sort_by_key(|e| e.timestamp_ms());
        Ok(EnrichedHistory::new(bars, events))
    }

    /// Loads the events of one stream `kind` from `source`, limited to
    /// `[from_ts, to_ts]`.
    ///
    /// `source` is passed explicitly (rather than read from `self`) so that a
    /// `DataSource::Mixed` entry can be resolved by recursing into its
    /// per-stream sub-source.
    ///
    /// # Errors
    ///
    /// * `ErrorKind::Unsupported` - `Rest` source without an attached fetcher.
    /// * `ErrorKind::Other` - the REST fetcher failed, or a JSON file could
    ///   not be parsed.
    /// * `ErrorKind::NotFound` - `Mixed` source has no entry for `kind`.
    /// * Any error returned by the underlying storage or file read.
    async fn load_stream_events(
        &self,
        source: &DataSource,
        symbol: &str,
        kind: StreamKind,
        from_ts: i64,
        to_ts: i64,
    ) -> std::io::Result<Vec<TimedEvent>> {
        match source {
            DataSource::Binary { storage_root } => {
                StorageRoot::new(storage_root.clone()).read_range(symbol, kind, from_ts, to_ts)
            }
            DataSource::Json { storage_root } => {
                self.read_json(storage_root, symbol, kind, from_ts, to_ts)
            }
            DataSource::Rest { exchange, account_type } => match &self.rest_fetcher {
                Some(fetcher) => fetcher
                    .fetch(*exchange, *account_type, symbol, kind, from_ts, to_ts)
                    .await
                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)),
                None => Err(std::io::Error::new(
                    std::io::ErrorKind::Unsupported,
                    "Rest source requires RestFetcher (use with_rest_fetcher)",
                )),
            },
            DataSource::Mixed { per_stream } => match per_stream.get(&kind) {
                // An `async fn` cannot call itself directly; boxing the
                // recursive future gives it a known size.
                Some(sub) => {
                    Box::pin(self.load_stream_events(sub, symbol, kind, from_ts, to_ts)).await
                }
                None => Err(std::io::Error::new(
                    std::io::ErrorKind::NotFound,
                    format!("Mixed source has no entry for {:?}", kind),
                )),
            },
        }
    }

    /// Reads `<root>/<symbol>/<kind>.json` (a JSON array of events) and keeps
    /// the events whose timestamp lies in `[from_ts, to_ts]`.
    ///
    /// A missing file is treated as "no data" and yields an empty vector.
    ///
    /// # Errors
    ///
    /// Any read failure other than `ErrorKind::NotFound` is returned as-is;
    /// a JSON parse failure is returned as `ErrorKind::Other`.
    fn read_json(
        &self,
        root: &std::path::Path,
        symbol: &str,
        kind: StreamKind,
        from_ts: i64,
        to_ts: i64,
    ) -> std::io::Result<Vec<TimedEvent>> {
        let path: PathBuf = root.join(symbol).join(format!("{}.json", kind.as_str()));

        // Read directly instead of checking `exists()` first: that avoids the
        // check-then-use race and stops unrelated I/O errors (which make
        // `exists()` return false) from being reported as "no data".
        // NOTE(review): this is a blocking read on the calling async task.
        let content = match std::fs::read_to_string(&path) {
            Ok(content) => content,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(Vec::new()),
            Err(e) => return Err(e),
        };

        let all: Vec<TimedEvent> = serde_json::from_str(&content)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;

        let window = from_ts..=to_ts;
        Ok(all
            .into_iter()
            .filter(|e| window.contains(&e.timestamp_ms()))
            .collect())
    }
}
#[cfg(test)]
mod tests {
    use super::{EnrichedDataLoader, RestFetcher};
    use crate::core::types::{Bar, FundingRate, OpenInterest};
    use crate::data_loader::{DataSource, StorageRoot, StreamKind, TimedEvent};
    use async_trait::async_trait;
    use digdigdig3::{AccountType, ExchangeId};
    use std::collections::HashMap;
    use std::path::PathBuf;
    use std::sync::Arc;

    /// Builds a bar at time `t`; the remaining fields are fixed placeholders.
    fn make_bar(t: i64) -> Bar {
        Bar::new(t, 1.0, 2.0, 0.5, 1.5, 100.0)
    }

    /// Builds a funding event stamped `ts`.
    fn make_funding_event(ts: i64) -> TimedEvent {
        TimedEvent::Funding(FundingRate {
            rate: 0.0001,
            next_funding_time: None,
            timestamp: ts,
        })
    }

    /// Builds an open-interest event stamped `ts`.
    fn make_oi_event(ts: i64) -> TimedEvent {
        TimedEvent::OpenInterest(OpenInterest {
            open_interest: 1000.0,
            open_interest_value: None,
            timestamp: ts,
        })
    }

    /// Returns an empty scratch directory unique to this process and `tag`.
    fn tempdir(tag: &str) -> PathBuf {
        let mut p = std::env::temp_dir();
        p.push(format!("mli_loader_test_{}_{}", std::process::id(), tag));
        // These directories are never removed after a run, so a recycled PID
        // could inherit stale files and skew the exact-count assertions below.
        // The error is ignored because the directory normally does not exist.
        let _ = std::fs::remove_dir_all(&p);
        std::fs::create_dir_all(&p).unwrap();
        p
    }

    /// With no extra streams requested, the history holds exactly the bars.
    #[tokio::test]
    async fn binary_bars_only_load() {
        let dir = tempdir("binary_bars_only");
        let loader = EnrichedDataLoader::new(DataSource::Binary {
            storage_root: dir.clone(),
        });
        let bars: Vec<Bar> = (0..5).map(|i| make_bar(i * 1_000)).collect();
        let history = loader.load("BTCUSDT", bars, &[]).await.unwrap();
        assert_eq!(history.bar_count(), 5);
        assert_eq!(history.event_count(), 5);
    }

    /// Bars and funding events from binary storage come back in timestamp order.
    #[tokio::test]
    async fn binary_multi_stream_sorted_order() {
        let dir = tempdir("binary_multi_stream");
        let storage = StorageRoot::new(&dir);
        let funding_timestamps = [500i64, 1500, 2500, 3500, 4500, 5500, 6500, 7500, 8500, 9500];
        for ts in funding_timestamps {
            storage.append("BTCUSDT", &make_funding_event(ts)).unwrap();
        }
        let loader = EnrichedDataLoader::new(DataSource::Binary {
            storage_root: dir.clone(),
        });
        let bars: Vec<Bar> = (0..5).map(|i| make_bar(i * 2_000)).collect();
        let history = loader
            .load("BTCUSDT", bars, &[StreamKind::Funding])
            .await
            .unwrap();
        assert_eq!(history.bar_count(), 5);
        assert!(history.event_count() >= 5);
        let timestamps: Vec<i64> = history.events.iter().map(|e| e.timestamp_ms()).collect();
        for w in timestamps.windows(2) {
            assert!(w[0] <= w[1], "events not sorted: {} > {}", w[0], w[1]);
        }
    }

    /// JSON events outside the window spanned by the bars are dropped.
    #[tokio::test]
    async fn json_read_filters_by_timestamp() {
        let dir = tempdir("json_filter");
        let symbol_dir = dir.join("BTCUSDT");
        std::fs::create_dir_all(&symbol_dir).unwrap();
        let events: Vec<TimedEvent> = [1000i64, 2000, 3000, 4000, 5000]
            .iter()
            .map(|&ts| make_funding_event(ts))
            .collect();
        let json = serde_json::to_string(&events).unwrap();
        std::fs::write(symbol_dir.join("funding.json"), json).unwrap();
        let loader = EnrichedDataLoader::new(DataSource::Json {
            storage_root: dir.clone(),
        });
        let bars: Vec<Bar> = vec![make_bar(2000), make_bar(4000)];
        let history = loader
            .load("BTCUSDT", bars, &[StreamKind::Funding])
            .await
            .unwrap();
        let funding_count = history
            .events
            .iter()
            .filter(|e| matches!(e, TimedEvent::Funding(_)))
            .count();
        assert_eq!(funding_count, 3, "expected 3 funding events in [2000,4000]");
    }

    /// A missing JSON file means "no data", not an error: only the bar remains.
    #[tokio::test]
    async fn json_missing_file_returns_empty() {
        let dir = tempdir("json_missing");
        let loader = EnrichedDataLoader::new(DataSource::Json {
            storage_root: dir.clone(),
        });
        let bars = vec![make_bar(1000)];
        let history = loader
            .load("BTCUSDT", bars, &[StreamKind::Funding])
            .await
            .unwrap();
        assert_eq!(history.event_count(), 1);
    }

    /// A REST source without an attached fetcher fails with `Unsupported`.
    #[tokio::test]
    async fn rest_without_fetcher_returns_unsupported() {
        let loader = EnrichedDataLoader::new(DataSource::Rest {
            exchange: ExchangeId::Binance,
            account_type: AccountType::FuturesCross,
        });
        let bars = vec![make_bar(1000)];
        let err = loader
            .load("BTCUSDT", bars, &[StreamKind::Funding])
            .await
            .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::Unsupported);
    }

    /// Fetcher stub that serves a fixed event list, filtered to the requested
    /// `[from_ts, to_ts]` window.
    struct StaticFetcher(Vec<TimedEvent>);

    #[async_trait]
    impl RestFetcher for StaticFetcher {
        async fn fetch(
            &self,
            _exchange: ExchangeId,
            _account_type: AccountType,
            _symbol: &str,
            _kind: StreamKind,
            from_ts: i64,
            to_ts: i64,
        ) -> Result<Vec<TimedEvent>, String> {
            // Filter before cloning so only the surviving events are copied.
            Ok(self
                .0
                .iter()
                .filter(|e| {
                    let ts = e.timestamp_ms();
                    ts >= from_ts && ts <= to_ts
                })
                .cloned()
                .collect())
        }
    }

    /// With a fetcher attached, REST events inside the bar window are returned.
    #[tokio::test]
    async fn rest_with_fetcher_returns_events() {
        let fetcher_events: Vec<TimedEvent> =
            [500i64, 1500, 2500].iter().map(|&ts| make_funding_event(ts)).collect();
        let fetcher = Arc::new(StaticFetcher(fetcher_events));
        let loader = EnrichedDataLoader::new(DataSource::Rest {
            exchange: ExchangeId::Binance,
            account_type: AccountType::FuturesCross,
        })
        .with_rest_fetcher(fetcher);
        let bars = vec![make_bar(1000), make_bar(2000)];
        let history = loader
            .load("BTCUSDT", bars, &[StreamKind::Funding])
            .await
            .unwrap();
        let funding_count = history
            .events
            .iter()
            .filter(|e| matches!(e, TimedEvent::Funding(_)))
            .count();
        assert_eq!(funding_count, 1, "only ts=1500 within [1000,2000]");
    }

    /// A mixed source routes each stream kind to its own sub-source:
    /// funding from binary storage, open interest from JSON.
    #[tokio::test]
    async fn mixed_per_stream_routing() {
        let binary_dir = tempdir("mixed_binary");
        let json_dir = tempdir("mixed_json");
        let storage = StorageRoot::new(&binary_dir);
        for ts in [1000i64, 2000, 3000] {
            storage.append("BTCUSDT", &make_funding_event(ts)).unwrap();
        }
        let oi_dir = json_dir.join("BTCUSDT");
        std::fs::create_dir_all(&oi_dir).unwrap();
        let oi_events: Vec<TimedEvent> =
            [1500i64, 2500].iter().map(|&ts| make_oi_event(ts)).collect();
        std::fs::write(
            oi_dir.join("open_interest.json"),
            serde_json::to_string(&oi_events).unwrap(),
        )
        .unwrap();
        let mut per_stream: HashMap<StreamKind, Box<DataSource>> = HashMap::new();
        per_stream.insert(
            StreamKind::Funding,
            Box::new(DataSource::Binary { storage_root: binary_dir }),
        );
        per_stream.insert(
            StreamKind::OpenInterest,
            Box::new(DataSource::Json { storage_root: json_dir }),
        );
        let loader = EnrichedDataLoader::new(DataSource::Mixed { per_stream });
        let bars = vec![make_bar(1000), make_bar(2000), make_bar(3000)];
        let history = loader
            .load("BTCUSDT", bars, &[StreamKind::Funding, StreamKind::OpenInterest])
            .await
            .unwrap();
        let funding_count = history
            .events
            .iter()
            .filter(|e| matches!(e, TimedEvent::Funding(_)))
            .count();
        let oi_count = history
            .events
            .iter()
            .filter(|e| matches!(e, TimedEvent::OpenInterest(_)))
            .count();
        assert_eq!(funding_count, 3);
        assert_eq!(oi_count, 2);
    }

    /// Requesting a stream the mixed source has no entry for yields `NotFound`.
    #[tokio::test]
    async fn mixed_missing_stream_returns_not_found() {
        let per_stream: HashMap<StreamKind, Box<DataSource>> = HashMap::new();
        let loader = EnrichedDataLoader::new(DataSource::Mixed { per_stream });
        let bars = vec![make_bar(1000)];
        let err = loader
            .load("BTCUSDT", bars, &[StreamKind::Funding])
            .await
            .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::NotFound);
    }
}