use crate::FixedPoint;
use crate::trade::Tick;
use std::path::Path;
use thiserror::Error;
/// Errors returned by the tick fixture loaders in this module.
///
/// Every variant records the fixture path as a `String` so a failing test
/// points straight at the file involved.
///
/// NOTE(review): each message interpolates `{source}` while the same field is
/// also marked `#[source]`, so error-chain reporters may print the cause twice.
#[derive(Debug, Error)]
pub enum LoaderError {
/// Raised when the fixture file cannot be opened.
#[error("File I/O error for {path}: {source}")]
Io {
/// Path of the file that could not be opened.
path: String,
/// Underlying I/O error.
#[source]
source: std::io::Error,
},
/// Raised when the `csv` reader cannot read or deserialize a record.
#[error("CSV parse error at line {line} in {path}: {source}")]
CsvParse {
/// Path of the file being parsed.
path: String,
/// Loader-computed line number of the failing record (a header row, when
/// present, counts as line 1).
line: usize,
/// Error reported by the `csv` crate.
#[source]
source: csv::Error,
},
/// Raised when a price or quantity string is rejected by `FixedPoint::from_str`.
#[error("Fixed-point conversion error at line {line} in {path}: {source}")]
FixedPoint {
/// Path of the file being parsed.
path: String,
/// Line number of the record whose numeric field failed to convert.
line: usize,
/// Conversion error from the fixed-point module.
#[source]
source: crate::fixed_point::FixedPointError,
},
/// Raised when the file parses cleanly but holds a different number of
/// records than the caller expected.
#[error("Record count mismatch in {path}: expected {expected}, got {actual}")]
CountMismatch {
/// Path of the file that was loaded.
path: String,
/// Number of records the caller asked for.
expected: usize,
/// Number of records actually parsed.
actual: usize,
},
}
/// One row of a headered aggTrades-style CSV fixture.
///
/// The file is read with `has_headers(true)`, so the `csv` crate matches
/// columns to fields by header name. The header row must therefore contain
/// `a`, `p`, `q`, `f`, `l`, `T` and `m`. These appear to follow Binance's
/// aggTrade field naming; confirm against the fixture generator.
#[derive(Debug, serde::Deserialize)]
struct TickRecord {
// Aggregate trade id; becomes `Tick::ref_id`.
a: i64,
// Price as decimal text, converted with `FixedPoint::from_str`.
p: String,
// Quantity as decimal text; becomes `Tick::volume`.
q: String,
// First sub-trade id; becomes `Tick::first_sub_id`.
f: i64,
// Last sub-trade id; becomes `Tick::last_sub_id`.
l: i64,
// Read from the `T` column. Presumably epoch milliseconds: the field name
// says so, and the tests assert a 13-digit value.
#[serde(rename = "T")]
timestamp_ms: i64,
// Buyer-is-maker flag as text; converted to `bool` in `into_tick`.
m: String,
}
impl TickRecord {
    /// Converts one raw headered CSV row into a [`Tick`].
    ///
    /// The buyer-maker flag accepts `true` in any ASCII case, with surrounding
    /// whitespace ignored. Python-style `True` and lowercase `true` exports
    /// therefore both decode correctly instead of the latter silently becoming
    /// `false`. Any other text decodes as `false`, as before.
    ///
    /// This format has no best-match column, so `is_best_match` is `None`.
    /// `best_bid` and `best_ask` are also left `None`.
    ///
    /// # Errors
    ///
    /// Returns the `FixedPointError` produced when `FixedPoint::from_str`
    /// rejects the price or quantity text.
    fn into_tick(self) -> Result<Tick, crate::fixed_point::FixedPointError> {
        // Case-insensitive so `true`/`TRUE` are not silently read as `false`.
        let is_buyer_maker = self.m.trim().eq_ignore_ascii_case("true");
        Ok(Tick {
            ref_id: self.a,
            price: FixedPoint::from_str(&self.p)?,
            volume: FixedPoint::from_str(&self.q)?,
            first_sub_id: self.f,
            last_sub_id: self.l,
            timestamp: self.timestamp_ms,
            is_buyer_maker,
            is_best_match: None,
            best_bid: None,
            best_ask: None,
        })
    }
}
/// Loads `BTCUSDT/BTCUSDT_aggTrades_20250901.csv` from the workspace
/// `test_data` directory.
///
/// # Errors
///
/// Returns a [`LoaderError`] if the file cannot be opened or parsed, or if it
/// does not contain exactly 5000 records.
pub fn load_btcusdt_test_data() -> Result<Vec<Tick>, LoaderError> {
    const EXPECTED_TRADES: usize = 5000;
    load_test_data(
        workspace_test_data_path("BTCUSDT/BTCUSDT_aggTrades_20250901.csv"),
        EXPECTED_TRADES,
    )
}
/// Loads `ETHUSDT/ETHUSDT_aggTrades_20250901.csv` from the workspace
/// `test_data` directory.
///
/// # Errors
///
/// Returns a [`LoaderError`] if the file cannot be opened or parsed, or if it
/// does not contain exactly 10000 records.
pub fn load_ethusdt_test_data() -> Result<Vec<Tick>, LoaderError> {
    const EXPECTED_TRADES: usize = 10000;
    load_test_data(
        workspace_test_data_path("ETHUSDT/ETHUSDT_aggTrades_20250901.csv"),
        EXPECTED_TRADES,
    )
}
fn workspace_test_data_path(relative_path: &str) -> std::path::PathBuf {
let manifest_dir = env!("CARGO_MANIFEST_DIR");
let workspace_root = std::path::Path::new(manifest_dir)
.parent() .unwrap()
.parent() .unwrap();
workspace_root.join("test_data").join(relative_path)
}
fn workspace_fixtures_path(relative_path: &str) -> std::path::PathBuf {
let manifest_dir = env!("CARGO_MANIFEST_DIR");
let workspace_root = std::path::Path::new(manifest_dir)
.parent() .unwrap()
.parent() .unwrap();
workspace_root
.join("tests")
.join("fixtures")
.join(relative_path)
}
/// One row of a headerless aggTrades CSV fixture.
///
/// The file is read with `has_headers(false)`, so the `csv` crate maps columns
/// to fields by position. The field order below must match the column order
/// in the file; the field names are never read from the data.
#[derive(Debug, serde::Deserialize)]
struct TickRecordHeaderless {
// Column 1: aggregate trade id; becomes `Tick::ref_id`.
ref_id: i64,
// Column 2: price as decimal text, converted with `FixedPoint::from_str`.
price: String,
// Column 3: quantity as decimal text; becomes `Tick::volume`.
quantity: String,
// Column 4: first sub-trade id.
first_sub_id: i64,
// Column 5: last sub-trade id.
last_sub_id: i64,
// Column 6: copied as-is to `Tick::timestamp`. Units are not checked here;
// presumably epoch milliseconds like the headered format (TODO confirm).
timestamp: i64,
// Column 7: buyer-is-maker flag as text; converted to `bool` in `into_tick`.
is_buyer_maker: String,
// Column 8: best-match flag as text; converted to `Some(bool)` in `into_tick`.
is_best_match: String,
}
impl TickRecordHeaderless {
    /// Converts one raw headerless CSV row into a [`Tick`].
    ///
    /// Both flag columns accept `true` in any ASCII case, with surrounding
    /// whitespace ignored. Python-style `True` and lowercase `true` therefore
    /// both decode correctly instead of the latter silently becoming `false`.
    /// Any other text decodes as `false`, as before.
    ///
    /// `is_best_match` is always `Some(..)` because this format carries the
    /// column. `best_bid` and `best_ask` are left `None`.
    ///
    /// # Errors
    ///
    /// Returns the `FixedPointError` produced when `FixedPoint::from_str`
    /// rejects the price or quantity text.
    fn into_tick(self) -> Result<Tick, crate::fixed_point::FixedPointError> {
        // Case-insensitive so `true`/`TRUE` are not silently read as `false`.
        let is_buyer_maker = self.is_buyer_maker.trim().eq_ignore_ascii_case("true");
        let is_best_match = self.is_best_match.trim().eq_ignore_ascii_case("true");
        Ok(Tick {
            ref_id: self.ref_id,
            price: FixedPoint::from_str(&self.price)?,
            volume: FixedPoint::from_str(&self.quantity)?,
            first_sub_id: self.first_sub_id,
            last_sub_id: self.last_sub_id,
            timestamp: self.timestamp,
            is_buyer_maker,
            is_best_match: Some(is_best_match),
            best_bid: None,
            best_ask: None,
        })
    }
}
/// Loads the headerless `BTCUSDT-aggTrades-sample-10k.csv` fixture from the
/// workspace `tests/fixtures` directory.
///
/// NOTE(review): the expected record count is 10001, not the 10000 the file
/// name suggests. Confirm against the fixture.
///
/// # Errors
///
/// Returns a [`LoaderError`] if the file cannot be opened or parsed, or if it
/// does not contain exactly 10001 records.
pub fn load_real_btcusdt_10k() -> Result<Vec<Tick>, LoaderError> {
    const EXPECTED_RECORDS: usize = 10_001;
    let fixture = workspace_fixtures_path("BTCUSDT-aggTrades-sample-10k.csv");
    load_headerless_data(fixture, EXPECTED_RECORDS)
}
fn load_headerless_data<P: AsRef<Path>>(
path: P,
expected_count: usize,
) -> Result<Vec<Tick>, LoaderError> {
let path_str = path.as_ref().to_string_lossy().to_string();
let file = std::fs::File::open(&path).map_err(|e| LoaderError::Io {
path: path_str.clone(),
source: e,
})?;
let csv_buffer_size = (expected_count * 100).max(64 * 1024).min(2 * 1024 * 1024);
let mut reader = csv::ReaderBuilder::new()
.has_headers(false)
.buffer_capacity(csv_buffer_size)
.from_reader(file);
let mut trades = Vec::with_capacity(expected_count);
let mut line = 1;
for result in reader.deserialize() {
let record: TickRecordHeaderless = result.map_err(|e| LoaderError::CsvParse {
path: path_str.clone(),
line,
source: e,
})?;
let trade = record
.into_tick()
.map_err(|e| LoaderError::FixedPoint {
path: path_str.clone(),
line,
source: e,
})?;
trades.push(trade);
line += 1;
}
let actual_count = trades.len();
if actual_count != expected_count {
return Err(LoaderError::CountMismatch {
path: path_str,
expected: expected_count,
actual: actual_count,
});
}
Ok(trades)
}
fn load_test_data<P: AsRef<Path>>(
path: P,
expected_count: usize,
) -> Result<Vec<Tick>, LoaderError> {
let path_str = path.as_ref().to_string_lossy().to_string();
let file = std::fs::File::open(&path).map_err(|e| LoaderError::Io {
path: path_str.clone(),
source: e,
})?;
let csv_buffer_size = (expected_count * 100).max(64 * 1024).min(2 * 1024 * 1024);
let mut reader = csv::ReaderBuilder::new()
.has_headers(true)
.buffer_capacity(csv_buffer_size)
.from_reader(file);
let mut trades = Vec::with_capacity(expected_count);
let mut line = 2;
for result in reader.deserialize() {
let record: TickRecord = result.map_err(|e| LoaderError::CsvParse {
path: path_str.clone(),
line,
source: e,
})?;
let trade = record
.into_tick()
.map_err(|e| LoaderError::FixedPoint {
path: path_str.clone(),
line,
source: e,
})?;
trades.push(trade);
line += 1;
}
let actual_count = trades.len();
if actual_count != expected_count {
return Err(LoaderError::CountMismatch {
path: path_str,
expected: expected_count,
actual: actual_count,
});
}
Ok(trades)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The BTCUSDT fixture loads completely and its first row decodes field by field.
    #[test]
    fn test_load_btcusdt_data() {
        let trades = load_btcusdt_test_data().expect("Failed to load BTCUSDT test data");
        assert_eq!(trades.len(), 5000, "BTCUSDT should have exactly 5000 trades");

        let head = &trades[0];
        assert_eq!(head.ref_id, 1);
        assert_eq!(head.price.to_string(), "50014.00859087");
        assert_eq!(head.volume.to_string(), "0.12019569");
        assert_eq!(head.first_sub_id, 1);
        assert_eq!(head.last_sub_id, 1);
        assert_eq!(head.timestamp, 1756710002083);
        assert!(!head.is_buyer_maker);
    }

    /// The ETHUSDT fixture loads completely, and every trade has positive
    /// price, volume and timestamp.
    #[test]
    fn test_load_ethusdt_data() {
        let trades = load_ethusdt_test_data().expect("Failed to load ETHUSDT test data");
        assert_eq!(trades.len(), 10000, "ETHUSDT should have exactly 10000 trades");

        trades.iter().for_each(|t| {
            assert!(t.price.0 > 0, "Price must be positive");
            assert!(t.volume.0 > 0, "Volume must be positive");
            assert!(t.timestamp > 0, "Timestamp must be positive");
        });
    }

    /// BTCUSDT timestamps never decrease from one trade to the next.
    #[test]
    fn test_temporal_integrity() {
        let trades = load_btcusdt_test_data().unwrap();
        for (offset, pair) in trades.windows(2).enumerate() {
            let (prev, curr) = (&pair[0], &pair[1]);
            assert!(
                curr.timestamp >= prev.timestamp,
                "Temporal integrity violation at trade {}: {} < {}",
                offset + 1,
                curr.timestamp,
                prev.timestamp
            );
        }
    }
}