// tesser_data/io.rs

1use std::collections::HashSet;
2use std::fs::File;
3use std::path::{Path, PathBuf};
4use std::str::FromStr;
5
6use anyhow::{anyhow, Context, Result};
7use arrow::array::{Array, Decimal128Array, StringArray, TimestampNanosecondArray};
8use arrow::record_batch::RecordBatch;
9use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc};
10use csv::{ReaderBuilder, WriterBuilder};
11use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
12use parquet::arrow::ArrowWriter;
13use parquet::basic::{Compression, ZstdLevel};
14use parquet::file::properties::WriterProperties;
15use rust_decimal::Decimal;
16
17use tesser_core::{Candle, Interval, Symbol, Tick};
18
19use crate::encoding::{candles_to_batch, ticks_to_batch};
20
/// Canonical formats supported by `read_dataset`/`write_dataset`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DatasetFormat {
    // Comma-separated values with a header row (see `write_csv`); the
    // fallback format when the extension is unrecognized.
    Csv,
    // Apache Parquet, written with ZSTD compression (see `write_parquet`).
    Parquet,
}
27
28impl DatasetFormat {
29    /// Attempt to infer the format based on the file extension.
30    #[must_use]
31    pub fn from_path(path: &Path) -> Self {
32        match path
33            .extension()
34            .and_then(|ext| ext.to_str())
35            .map(|ext| ext.to_ascii_lowercase())
36            .as_deref()
37        {
38            Some("parquet") => Self::Parquet,
39            _ => Self::Csv,
40        }
41    }
42}
43
/// Fully materialized dataset containing normalized candles.
pub struct CandleDataset {
    // Format the dataset was decoded from (inferred from the file extension).
    pub format: DatasetFormat,
    // Candles sorted by ascending timestamp (both readers sort before returning).
    pub candles: Vec<Candle>,
}
49
50/// Load a dataset from disk using either CSV or Parquet encodings.
51pub fn read_dataset(path: &Path) -> Result<CandleDataset> {
52    let format = DatasetFormat::from_path(path);
53    let candles = match format {
54        DatasetFormat::Csv => read_csv(path)
55            .with_context(|| format!("failed to load CSV dataset {}", path.display()))?,
56        DatasetFormat::Parquet => read_parquet(path)
57            .with_context(|| format!("failed to load parquet dataset {}", path.display()))?,
58    };
59    Ok(CandleDataset { format, candles })
60}
61
62/// Persist the provided candles to disk using the requested format.
63pub fn write_dataset(path: &Path, format: DatasetFormat, candles: &[Candle]) -> Result<()> {
64    match format {
65        DatasetFormat::Csv => write_csv(path, candles),
66        DatasetFormat::Parquet => write_parquet(path, candles),
67    }
68}
69
/// Helper for normalizing and persisting tick data to parquet files.
pub struct TicksWriter {
    // Destination parquet file; nothing is written until `finish`.
    path: PathBuf,
    // Ticks buffered in insertion order; sorted and deduped by `finish`.
    rows: Vec<Tick>,
    // Trade identifiers already accepted, used by `push` to drop exact
    // duplicate trades up front.
    seen_ids: HashSet<String>,
}
76
77impl TicksWriter {
78    /// Build a writer bound to the provided destination path.
79    pub fn new(path: impl Into<PathBuf>) -> Self {
80        Self {
81            path: path.into(),
82            rows: Vec::new(),
83            seen_ids: HashSet::new(),
84        }
85    }
86
87    /// Current number of buffered ticks (before final dedupe).
88    pub fn len(&self) -> usize {
89        self.rows.len()
90    }
91
92    /// Returns true when no ticks have been appended.
93    pub fn is_empty(&self) -> bool {
94        self.rows.is_empty()
95    }
96
97    /// Append a single tick, skipping duplicates that reuse the same trade identifier.
98    pub fn push(&mut self, trade_id: Option<String>, tick: Tick) {
99        if let Some(id) = trade_id.filter(|value| !value.is_empty()) {
100            if !self.seen_ids.insert(id) {
101                return;
102            }
103        }
104        self.rows.push(tick);
105    }
106
107    /// Extend the writer with one or more `(trade_id, tick)` tuples.
108    pub fn extend<I>(&mut self, trades: I)
109    where
110        I: IntoIterator<Item = (Option<String>, Tick)>,
111    {
112        for (trade_id, tick) in trades {
113            self.push(trade_id, tick);
114        }
115    }
116
117    /// Finalize the parquet file once all ticks have been queued.
118    pub fn finish(mut self) -> Result<()> {
119        self.rows
120            .sort_by_key(|tick| tick.exchange_timestamp.timestamp_millis());
121        self.rows.dedup_by(|a, b| {
122            a.exchange_timestamp == b.exchange_timestamp
123                && a.price == b.price
124                && a.size == b.size
125                && a.side == b.side
126        });
127        write_ticks_parquet(&self.path, &self.rows)
128    }
129}
130
131fn read_csv(path: &Path) -> Result<Vec<Candle>> {
132    let mut reader = ReaderBuilder::new()
133        .flexible(true)
134        .from_path(path)
135        .with_context(|| format!("failed to open {}", path.display()))?;
136    let mut candles = Vec::new();
137    for row in reader.records() {
138        let record = row.with_context(|| format!("invalid row in {}", path.display()))?;
139        let timestamp = parse_timestamp(
140            record
141                .get(1)
142                .ok_or_else(|| anyhow!("missing timestamp column in {}", path.display()))?,
143        )?;
144        let symbol = match record.get(0) {
145            Some(value) if !value.trim().is_empty() => value.to_string(),
146            _ => infer_symbol(path).ok_or_else(|| {
147                anyhow!(
148                    "missing symbol column and unable to infer from {}",
149                    path.display()
150                )
151            })?,
152        };
153        let candle = Candle {
154            symbol: Symbol::from(symbol.as_str()),
155            interval: infer_interval(path).unwrap_or(Interval::OneMinute),
156            open: parse_decimal(record.get(2), "open", path)?,
157            high: parse_decimal(record.get(3), "high", path)?,
158            low: parse_decimal(record.get(4), "low", path)?,
159            close: parse_decimal(record.get(5), "close", path)?,
160            volume: parse_decimal(record.get(6), "volume", path)?,
161            timestamp,
162        };
163        candles.push(candle);
164    }
165    candles.sort_by_key(|c| c.timestamp);
166    Ok(candles)
167}
168
169fn read_parquet(path: &Path) -> Result<Vec<Candle>> {
170    let file = File::open(path)
171        .with_context(|| format!("failed to open parquet file {}", path.display()))?;
172    let reader = ParquetRecordBatchReaderBuilder::try_new(file)?
173        .with_batch_size(1024)
174        .build()?;
175    let mut columns: Option<CandleColumns> = None;
176    let mut candles = Vec::new();
177    for batch in reader {
178        let batch = batch?;
179        if columns.is_none() {
180            columns = Some(CandleColumns::from_batch(&batch)?);
181        }
182        let column_mapping = columns.as_ref().unwrap();
183        for row in 0..batch.num_rows() {
184            candles.push(column_mapping.decode(&batch, row)?);
185        }
186    }
187    candles.sort_by_key(|c| c.timestamp);
188    Ok(candles)
189}
190
191fn write_csv(path: &Path, candles: &[Candle]) -> Result<()> {
192    if let Some(parent) = path.parent() {
193        std::fs::create_dir_all(parent)
194            .with_context(|| format!("failed to create directory {}", parent.display()))?;
195    }
196    let mut writer = WriterBuilder::new()
197        .has_headers(true)
198        .from_path(path)
199        .with_context(|| format!("failed to create {}", path.display()))?;
200    writer.write_record([
201        "symbol",
202        "timestamp",
203        "open",
204        "high",
205        "low",
206        "close",
207        "volume",
208    ])?;
209    for candle in candles {
210        writer.write_record([
211            candle.symbol.code(),
212            &candle.timestamp.to_rfc3339(),
213            &candle.open.to_string(),
214            &candle.high.to_string(),
215            &candle.low.to_string(),
216            &candle.close.to_string(),
217            &candle.volume.to_string(),
218        ])?;
219    }
220    writer.flush()?;
221    Ok(())
222}
223
224fn write_parquet(path: &Path, candles: &[Candle]) -> Result<()> {
225    if let Some(parent) = path.parent() {
226        std::fs::create_dir_all(parent)
227            .with_context(|| format!("failed to create directory {}", parent.display()))?;
228    }
229    let batch = candles_to_batch(candles)?;
230    let file =
231        File::create(path).with_context(|| format!("failed to create {}", path.display()))?;
232    let props = WriterProperties::builder()
233        .set_compression(Compression::ZSTD(ZstdLevel::default()))
234        .build();
235    let mut writer = ArrowWriter::try_new(file, batch.schema(), Some(props))?;
236    writer.write(&batch)?;
237    writer.close()?;
238    Ok(())
239}
240
241fn write_ticks_parquet(path: &Path, ticks: &[Tick]) -> Result<()> {
242    if let Some(parent) = path.parent() {
243        std::fs::create_dir_all(parent)
244            .with_context(|| format!("failed to create directory {}", parent.display()))?;
245    }
246    let batch = ticks_to_batch(ticks)?;
247    let file =
248        File::create(path).with_context(|| format!("failed to create {}", path.display()))?;
249    let props = WriterProperties::builder()
250        .set_compression(Compression::ZSTD(ZstdLevel::default()))
251        .build();
252    let mut writer = ArrowWriter::try_new(file, batch.schema(), Some(props))?;
253    writer.write(&batch)?;
254    writer.close()?;
255    Ok(())
256}
257
258fn parse_decimal(value: Option<&str>, column: &str, path: &Path) -> Result<Decimal> {
259    let text = value.ok_or_else(|| anyhow!("missing {column} column in {}", path.display()))?;
260    Decimal::from_str(text)
261        .with_context(|| format!("invalid {column} value '{text}' in {}", path.display()))
262}
263
264fn parse_timestamp(value: &str) -> Result<DateTime<Utc>> {
265    if let Ok(ts) = DateTime::parse_from_rfc3339(value) {
266        return Ok(ts.with_timezone(&Utc));
267    }
268    if let Ok(dt) = NaiveDateTime::parse_from_str(value, "%Y-%m-%d %H:%M:%S") {
269        return Ok(DateTime::<Utc>::from_naive_utc_and_offset(dt, Utc));
270    }
271    if let Ok(date) = NaiveDate::parse_from_str(value, "%Y-%m-%d") {
272        let dt = date
273            .and_hms_opt(0, 0, 0)
274            .ok_or_else(|| anyhow!("invalid date '{value}'"))?;
275        return Ok(DateTime::<Utc>::from_naive_utc_and_offset(dt, Utc));
276    }
277    Err(anyhow!("unable to parse timestamp '{value}'"))
278}
279
/// Infer the symbol from the dataset's parent directory name,
/// e.g. `data/BTCUSDT/1m.csv` -> `BTCUSDT`.
fn infer_symbol(path: &Path) -> Option<String> {
    let directory = path.parent()?;
    let name = directory.file_name()?;
    Some(name.to_string_lossy().into_owned())
}
285
286fn infer_interval(path: &Path) -> Option<Interval> {
287    path.file_stem()
288        .and_then(|stem| stem.to_str())
289        .and_then(|stem| stem.split('_').next())
290        .and_then(|token| Interval::from_str(token).ok())
291}
292
/// Resolved column indices for decoding candles out of a parquet record batch.
///
/// Each field is the index of the identically-named column within the batch
/// schema, resolved once by `from_batch` and reused for every row.
struct CandleColumns {
    symbol: usize,
    interval: usize,
    open: usize,
    high: usize,
    low: usize,
    close: usize,
    volume: usize,
    timestamp: usize,
}
303
impl CandleColumns {
    /// Resolve the index of every required column from the batch schema.
    ///
    /// # Errors
    /// Returns an error if any expected column is missing from the schema.
    fn from_batch(batch: &RecordBatch) -> Result<Self> {
        Ok(Self {
            symbol: column_index(batch, "symbol")?,
            interval: column_index(batch, "interval")?,
            open: column_index(batch, "open")?,
            high: column_index(batch, "high")?,
            low: column_index(batch, "low")?,
            close: column_index(batch, "close")?,
            volume: column_index(batch, "volume")?,
            timestamp: column_index(batch, "timestamp")?,
        })
    }

    /// Decode a single row of `batch` into a `Candle`.
    ///
    /// # Errors
    /// Returns an error when a cell is null, a column has an unexpected Arrow
    /// type, or the interval string fails to parse as an `Interval`.
    fn decode(&self, batch: &RecordBatch, row: usize) -> Result<Candle> {
        let symbol = string_value(batch, self.symbol, row)?;
        let interval_raw = string_value(batch, self.interval, row)?;
        // Wrap the parse error with the offending raw string for context.
        let interval =
            Interval::from_str(&interval_raw).map_err(|err| anyhow!("{interval_raw}: {err}"))?;
        Ok(Candle {
            symbol: Symbol::from(symbol.as_str()),
            interval,
            open: decimal_value(batch, self.open, row)?,
            high: decimal_value(batch, self.high, row)?,
            low: decimal_value(batch, self.low, row)?,
            close: decimal_value(batch, self.close, row)?,
            volume: decimal_value(batch, self.volume, row)?,
            timestamp: timestamp_value(batch, self.timestamp, row)?,
        })
    }
}
335
336fn column_index(batch: &RecordBatch, name: &str) -> Result<usize> {
337    batch
338        .schema()
339        .column_with_name(name)
340        .map(|(idx, _)| idx)
341        .ok_or_else(|| anyhow!("column '{name}' missing from schema"))
342}
343
344fn string_value(batch: &RecordBatch, column: usize, row: usize) -> Result<String> {
345    let array = batch
346        .column(column)
347        .as_any()
348        .downcast_ref::<StringArray>()
349        .ok_or_else(|| anyhow!("column {column} is not Utf8"))?;
350    if array.is_null(row) {
351        return Err(anyhow!("column {column} contains null string"));
352    }
353    Ok(array.value(row).to_string())
354}
355
356fn decimal_value(batch: &RecordBatch, column: usize, row: usize) -> Result<Decimal> {
357    let array = batch
358        .column(column)
359        .as_any()
360        .downcast_ref::<Decimal128Array>()
361        .ok_or_else(|| anyhow!("column {column} is not decimal"))?;
362    if array.is_null(row) {
363        return Err(anyhow!("column {column} contains null decimal"));
364    }
365    Ok(Decimal::from_i128_with_scale(
366        array.value(row),
367        array.scale() as u32,
368    ))
369}
370
371fn timestamp_value(batch: &RecordBatch, column: usize, row: usize) -> Result<DateTime<Utc>> {
372    let array = batch
373        .column(column)
374        .as_any()
375        .downcast_ref::<TimestampNanosecondArray>()
376        .ok_or_else(|| anyhow!("column {column} is not timestamp"))?;
377    if array.is_null(row) {
378        return Err(anyhow!("column {column} contains null timestamp"));
379    }
380    let nanos = array.value(row);
381    let secs = nanos.div_euclid(1_000_000_000);
382    let sub = nanos.rem_euclid(1_000_000_000) as u32;
383    DateTime::<Utc>::from_timestamp(secs, sub)
384        .ok_or_else(|| anyhow!("timestamp overflow for value {nanos}"))
385}
386
#[cfg(test)]
mod tests {
    use std::fs::File;

    use chrono::{Duration, TimeZone, Utc};
    use rust_decimal::{prelude::FromPrimitive, Decimal};
    use tempfile::tempdir;
    use tesser_core::{Side, Symbol};

    use super::*;

    /// Build four consecutive one-minute BTCUSDT candles starting ten
    /// minutes in the past.
    fn sample_candles() -> Vec<Candle> {
        let base = Utc::now() - Duration::minutes(10);
        (0..4)
            .map(|idx| Candle {
                symbol: Symbol::from("BTCUSDT"),
                interval: Interval::OneMinute,
                open: Decimal::new(10 + idx as i64, 0),
                high: Decimal::new(11 + idx as i64, 0),
                low: Decimal::new(9 + idx as i64, 0),
                close: Decimal::new(10 + idx as i64, 0),
                volume: Decimal::new(1, 0),
                timestamp: base + Duration::minutes(idx as i64),
            })
            .collect()
    }

    /// Writing then reading a CSV dataset preserves the candle count.
    #[test]
    fn round_trip_csv() -> Result<()> {
        let temp = tempdir()?;
        // File stem `1m_...` lets the reader infer the one-minute interval.
        let path = temp.path().join("1m_BTCUSDT.csv");
        let candles = sample_candles();
        write_dataset(&path, DatasetFormat::Csv, &candles)?;
        let dataset = read_dataset(&path)?;
        assert_eq!(dataset.candles.len(), candles.len());
        Ok(())
    }

    /// Writing then reading a parquet dataset preserves the candle count.
    #[test]
    fn round_trip_parquet() -> Result<()> {
        let temp = tempdir()?;
        let path = temp.path().join("1m_BTCUSDT.parquet");
        let candles = sample_candles();
        write_dataset(&path, DatasetFormat::Parquet, &candles)?;
        let dataset = read_dataset(&path)?;
        assert_eq!(dataset.candles.len(), candles.len());
        Ok(())
    }

    /// `TicksWriter` drops repeated trade ids on push and collapses
    /// payload-identical id-less ticks during `finish`.
    #[test]
    fn ticks_writer_dedupes_trade_ids_and_payloads() -> Result<()> {
        let temp = tempdir()?;
        let path = temp.path().join("ticks.parquet");
        let mut writer = TicksWriter::new(&path);
        writer.push(
            Some("trade-1".to_string()),
            sample_tick(1_000, 100.0, 1.0, Side::Buy),
        );
        // Duplicate ID should be skipped.
        writer.push(
            Some("trade-1".to_string()),
            sample_tick(1_000, 100.0, 1.0, Side::Buy),
        );
        // Same payload without ID should dedupe during finish.
        writer.extend([
            (None, sample_tick(2_000, 101.0, 2.0, Side::Sell)),
            (None, sample_tick(2_000, 101.0, 2.0, Side::Sell)),
        ]);
        writer.finish()?;

        // Re-read the parquet file and count surviving rows.
        let file = File::open(&path)?;
        let reader = ParquetRecordBatchReaderBuilder::try_new(file)?
            .with_batch_size(8)
            .build()?;
        let mut rows = 0;
        for batch in reader {
            rows += batch?.num_rows();
        }
        assert_eq!(rows, 2);
        Ok(())
    }

    /// Build a tick at the given epoch-millis with matching exchange and
    /// receive timestamps.
    fn sample_tick(ts_ms: i64, price: f64, size: f64, side: Side) -> Tick {
        let price = Decimal::from_f64(price).expect("valid price");
        let size = Decimal::from_f64(size).expect("valid size");
        let timestamp = Utc
            .timestamp_millis_opt(ts_ms)
            .single()
            .expect("valid timestamp");
        Tick {
            symbol: Symbol::from("BTCUSDT"),
            price,
            size,
            side,
            exchange_timestamp: timestamp,
            received_at: timestamp,
        }
    }
}