// tesser_data/etl/mod.rs

1use std::collections::BTreeMap;
2use std::fs::{self, File};
3use std::io::BufReader;
4use std::path::Path;
5use std::str::FromStr;
6use std::sync::Arc;
7
8use anyhow::{anyhow, bail, Context, Result};
9use arrow::array::{ArrayRef, Decimal128Builder, Int64Builder, StringBuilder};
10use arrow::datatypes::{DataType, SchemaRef};
11use arrow::record_batch::RecordBatch;
12use chrono::{DateTime, Datelike, Utc};
13use csv::StringRecord;
14use glob::glob;
15use parquet::arrow::ArrowWriter;
16use rust_decimal::prelude::RoundingStrategy;
17use rust_decimal::Decimal;
18use serde::Deserialize;
19use tracing::info;
20
21use crate::schema::{
22    canonical_candle_schema, CANONICAL_DECIMAL_PRECISION, CANONICAL_DECIMAL_SCALE,
23    CANONICAL_DECIMAL_SCALE_U32,
24};
25
/// Strategy that controls how normalized candles are partitioned.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Partitioning {
    /// One partition directory per calendar day (`year=/month=/day=`).
    Daily,
    /// One partition directory per calendar month (`year=/month=`).
    Monthly,
}
32
/// Declarative mapping describing how a raw CSV file should be parsed.
#[derive(Debug, Clone, Deserialize)]
pub struct MappingConfig {
    /// CSV dialect options (delimiter, header presence).
    pub csv: CsvConfig,
    /// Column-to-field assignments for the canonical candle schema.
    pub fields: FieldMapping,
    /// Interval label stamped on every output row; defaults to "1m".
    #[serde(default = "MappingConfig::default_interval")]
    pub interval: String,
}
41
42impl MappingConfig {
43    fn default_interval() -> String {
44        "1m".to_string()
45    }
46}
47
/// CSV dialect configuration for a source file.
#[derive(Debug, Clone, Deserialize)]
pub struct CsvConfig {
    // Delimiter given as a string in config; only its first byte is used
    // (see `CsvConfig::delimiter`). `None` means comma.
    #[serde(default)]
    delimiter: Option<String>,
    // Whether the first record is a header row; defaults to true.
    #[serde(default = "CsvConfig::default_has_header")]
    has_header: bool,
}
55
56impl CsvConfig {
57    fn delimiter(&self) -> u8 {
58        self.delimiter
59            .as_deref()
60            .and_then(|value| value.as_bytes().first().copied())
61            .unwrap_or(b',')
62    }
63
64    fn default_has_header() -> bool {
65        true
66    }
67
68    fn has_header(&self) -> bool {
69        self.has_header
70    }
71}
72
73impl Default for CsvConfig {
74    fn default() -> Self {
75        Self {
76            delimiter: None,
77            has_header: true,
78        }
79    }
80}
81
/// Column assignments for each canonical candle field.
#[derive(Debug, Clone, Deserialize)]
pub struct FieldMapping {
    /// Timestamp column plus its unit/format.
    pub timestamp: TimestampField,
    pub open: ValueField,
    pub high: ValueField,
    pub low: ValueField,
    pub close: ValueField,
    /// Optional volume column; rows carry a null volume when absent.
    #[serde(default)]
    pub volume: Option<ValueField>,
}
92
/// Location and encoding of the timestamp column.
#[derive(Debug, Clone, Deserialize)]
pub struct TimestampField {
    /// Zero-based CSV column index.
    pub col: usize,
    /// Resolution of integer timestamps; ignored for RFC 3339 input.
    #[serde(default)]
    pub unit: TimestampUnit,
    /// Textual format of the column; defaults to integer Unix epoch.
    #[serde(default)]
    pub format: TimestampFormat,
}
101
102impl TimestampField {
103    fn parse(&self, record: &StringRecord) -> Result<i64> {
104        let raw = record
105            .get(self.col)
106            .ok_or_else(|| anyhow!("row missing timestamp column {}", self.col))?
107            .trim();
108        if raw.is_empty() {
109            bail!("timestamp column {} is empty", self.col);
110        }
111        match self.format {
112            TimestampFormat::Unix => self.parse_unix(raw),
113            TimestampFormat::Rfc3339 => Self::parse_rfc3339(raw),
114        }
115    }
116
117    fn parse_unix(&self, raw: &str) -> Result<i64> {
118        let value: i64 = raw
119            .parse()
120            .with_context(|| format!("invalid timestamp '{raw}'"))?;
121        let nanos = value
122            .checked_mul(self.unit.multiplier())
123            .ok_or_else(|| anyhow!("timestamp overflow for value {value}"))?;
124        Ok(nanos)
125    }
126
127    fn parse_rfc3339(raw: &str) -> Result<i64> {
128        let dt = DateTime::parse_from_rfc3339(raw)
129            .with_context(|| format!("invalid RFC3339 timestamp '{raw}'"))?
130            .with_timezone(&Utc);
131        dt.timestamp_nanos_opt()
132            .ok_or_else(|| anyhow!("timestamp overflow for value {raw}"))
133    }
134}
135
136#[derive(Debug, Clone, Copy, Deserialize)]
137#[serde(rename_all = "lowercase")]
138#[derive(Default)]
139pub enum TimestampUnit {
140    Seconds,
141    #[default]
142    Milliseconds,
143    Microseconds,
144    Nanoseconds,
145}
146
147impl TimestampUnit {
148    fn multiplier(self) -> i64 {
149        match self {
150            TimestampUnit::Seconds => 1_000_000_000,
151            TimestampUnit::Milliseconds => 1_000_000,
152            TimestampUnit::Microseconds => 1_000,
153            TimestampUnit::Nanoseconds => 1,
154        }
155    }
156}
157
158#[derive(Debug, Clone, Copy, Deserialize)]
159#[serde(rename_all = "lowercase")]
160#[derive(Default)]
161pub enum TimestampFormat {
162    #[default]
163    Unix,
164    Rfc3339,
165}
166
/// Zero-based CSV column index holding a numeric value.
#[derive(Debug, Clone, Deserialize)]
pub struct ValueField {
    pub col: usize,
}
171
172impl ValueField {
173    fn parse_decimal(&self, record: &StringRecord, label: &str) -> Result<Decimal> {
174        let raw = record
175            .get(self.col)
176            .ok_or_else(|| anyhow!("row missing {label} column {}", self.col))?
177            .trim();
178        if raw.is_empty() {
179            bail!("{label} column {} is empty", self.col);
180        }
181        Decimal::from_str(raw).map_err(|err| anyhow!("invalid {} value '{}': {err}", label, raw))
182    }
183}
184
/// ETL pipeline that converts arbitrary CSVs into the canonical Arrow schema.
pub struct Pipeline {
    // Parsing/partitioning configuration applied to every source file.
    mapping: MappingConfig,
}
189
190impl Pipeline {
191    pub fn new(mapping: MappingConfig) -> Self {
192        Self { mapping }
193    }
194
195    pub fn run(
196        &self,
197        pattern: &str,
198        output: &Path,
199        symbol: &str,
200        partitioning: Partitioning,
201    ) -> Result<usize> {
202        let mut total_rows = 0usize;
203        let mut matched = false;
204        for entry in glob(pattern).with_context(|| format!("invalid source glob {pattern}"))? {
205            let path = entry?;
206            matched = true;
207            let rows = self.load_rows(&path, symbol)?;
208            if rows.is_empty() {
209                continue;
210            }
211            let count = rows.len();
212            self.write_partitions(rows, output, partitioning)?;
213            total_rows += count;
214            info!(path = %path.display(), rows = count, "normalized source file");
215        }
216        if !matched {
217            bail!("no files matched pattern {pattern}");
218        }
219        Ok(total_rows)
220    }
221
222    fn load_rows(&self, path: &Path, symbol: &str) -> Result<Vec<CanonicalCandle>> {
223        let interval_label = self.mapping.interval.clone();
224        let file = File::open(path)
225            .with_context(|| format!("failed to open source file {}", path.display()))?;
226        let mut reader = csv::ReaderBuilder::new()
227            .delimiter(self.mapping.csv.delimiter())
228            .has_headers(self.mapping.csv.has_header())
229            .from_reader(BufReader::new(file));
230        let mut rows = Vec::new();
231        for (idx, record) in reader.records().enumerate() {
232            let record = record.with_context(|| format!("failed to read record {}", idx + 1))?;
233            let timestamp = self
234                .mapping
235                .fields
236                .timestamp
237                .parse(&record)
238                .with_context(|| format!("invalid timestamp in {}", path.display()))?;
239            let open = self
240                .mapping
241                .fields
242                .open
243                .parse_decimal(&record, "open")
244                .with_context(|| format!("invalid open price in {}", path.display()))?;
245            let high = self
246                .mapping
247                .fields
248                .high
249                .parse_decimal(&record, "high")
250                .with_context(|| format!("invalid high price in {}", path.display()))?;
251            let low = self
252                .mapping
253                .fields
254                .low
255                .parse_decimal(&record, "low")
256                .with_context(|| format!("invalid low price in {}", path.display()))?;
257            let close = self
258                .mapping
259                .fields
260                .close
261                .parse_decimal(&record, "close")
262                .with_context(|| format!("invalid close price in {}", path.display()))?;
263            if high < low {
264                bail!(
265                    "row {} failed validation: high {} < low {}",
266                    idx + 1,
267                    high,
268                    low
269                );
270            }
271            let volume = if let Some(field) = &self.mapping.fields.volume {
272                let parsed = field
273                    .parse_decimal(&record, "volume")
274                    .with_context(|| format!("invalid volume in {}", path.display()))?;
275                if parsed < Decimal::ZERO {
276                    bail!(
277                        "row {} failed validation: negative volume {}",
278                        idx + 1,
279                        parsed
280                    );
281                }
282                Some(parsed)
283            } else {
284                None
285            };
286            rows.push(CanonicalCandle {
287                timestamp,
288                symbol: symbol.to_string(),
289                interval: interval_label.clone(),
290                open,
291                high,
292                low,
293                close,
294                volume,
295            });
296        }
297        Ok(rows)
298    }
299
300    fn write_partitions(
301        &self,
302        rows: Vec<CanonicalCandle>,
303        output: &Path,
304        partitioning: Partitioning,
305    ) -> Result<()> {
306        let schema = canonical_candle_schema();
307        let mut partitions: BTreeMap<String, Vec<CanonicalCandle>> = BTreeMap::new();
308        for row in rows {
309            let key = partition_path(&row.symbol, &row.interval, row.timestamp, partitioning)?;
310            partitions.entry(key).or_default().push(row);
311        }
312        for (counter, (relative, records)) in partitions.into_iter().enumerate() {
313            let dir = output.join(&relative);
314            fs::create_dir_all(&dir)
315                .with_context(|| format!("failed to create {}", dir.display()))?;
316            let file_path = dir.join(format!("part-{counter:05}.parquet"));
317            let batch = rows_to_batch(&records, &schema)?;
318            let file = File::create(&file_path)
319                .with_context(|| format!("failed to create {}", file_path.display()))?;
320            let mut writer = ArrowWriter::try_new(file, schema.clone(), None)?;
321            writer.write(&batch)?;
322            writer.close()?;
323        }
324        Ok(())
325    }
326}
327
// One normalized OHLCV row in the canonical representation.
#[derive(Clone)]
struct CanonicalCandle {
    // Nanoseconds since the Unix epoch (UTC).
    timestamp: i64,
    symbol: String,
    interval: String,
    open: Decimal,
    high: Decimal,
    low: Decimal,
    close: Decimal,
    // None when the mapping declares no volume column.
    volume: Option<Decimal>,
}
339
340fn rows_to_batch(rows: &[CanonicalCandle], schema: &SchemaRef) -> Result<RecordBatch> {
341    let decimal_type = DataType::Decimal128(CANONICAL_DECIMAL_PRECISION, CANONICAL_DECIMAL_SCALE);
342    let mut timestamps = Int64Builder::new();
343    let mut symbols = StringBuilder::new();
344    let mut intervals = StringBuilder::new();
345    let mut open_builder = Decimal128Builder::new().with_data_type(decimal_type.clone());
346    let mut high_builder = Decimal128Builder::new().with_data_type(decimal_type.clone());
347    let mut low_builder = Decimal128Builder::new().with_data_type(decimal_type.clone());
348    let mut close_builder = Decimal128Builder::new().with_data_type(decimal_type.clone());
349    let mut volume_builder = Decimal128Builder::new().with_data_type(decimal_type.clone());
350
351    for row in rows {
352        timestamps.append_value(row.timestamp);
353        symbols.append_value(&row.symbol);
354        intervals.append_value(&row.interval);
355        open_builder.append_value(decimal_to_i128(row.open)?);
356        high_builder.append_value(decimal_to_i128(row.high)?);
357        low_builder.append_value(decimal_to_i128(row.low)?);
358        close_builder.append_value(decimal_to_i128(row.close)?);
359        if let Some(volume) = row.volume {
360            volume_builder.append_value(decimal_to_i128(volume)?);
361        } else {
362            volume_builder.append_null();
363        }
364    }
365
366    let columns: Vec<ArrayRef> = vec![
367        Arc::new(timestamps.finish()),
368        Arc::new(symbols.finish()),
369        Arc::new(intervals.finish()),
370        Arc::new(open_builder.finish()),
371        Arc::new(high_builder.finish()),
372        Arc::new(low_builder.finish()),
373        Arc::new(close_builder.finish()),
374        Arc::new(volume_builder.finish()),
375    ];
376
377    RecordBatch::try_new(schema.clone(), columns).map_err(Into::into)
378}
379
380fn decimal_to_i128(value: Decimal) -> Result<i128> {
381    let mut normalized = value;
382    if normalized.scale() > CANONICAL_DECIMAL_SCALE_U32 {
383        normalized = normalized.round_dp_with_strategy(
384            CANONICAL_DECIMAL_SCALE_U32,
385            RoundingStrategy::MidpointNearestEven,
386        );
387    }
388    let scale = normalized.scale();
389    if scale > CANONICAL_DECIMAL_SCALE_U32 {
390        bail!(
391            "value scale {} exceeds canonical precision {CANONICAL_DECIMAL_SCALE_U32}",
392            scale
393        );
394    }
395    let diff = CANONICAL_DECIMAL_SCALE_U32 - scale;
396    let factor = 10i128
397        .checked_pow(diff)
398        .ok_or_else(|| anyhow!("decimal scaling factor overflow"))?;
399    normalized
400        .mantissa()
401        .checked_mul(factor)
402        .ok_or_else(|| anyhow!("decimal mantissa overflow"))
403}
404
405fn partition_path(
406    symbol: &str,
407    interval: &str,
408    timestamp: i64,
409    partitioning: Partitioning,
410) -> Result<String> {
411    let dt = datetime_from_ns(timestamp)?;
412    let mut segments = vec![
413        format!("symbol={}", sanitize_symbol(symbol)),
414        format!("interval={}", sanitize_interval(interval)),
415    ];
416    segments.push(format!("year={:04}", dt.year()));
417    segments.push(format!("month={:02}", dt.month()));
418    if matches!(partitioning, Partitioning::Daily) {
419        segments.push(format!("day={:02}", dt.day()));
420    }
421    Ok(segments.join("/"))
422}
423
/// Makes a symbol safe to embed in a partition directory name.
///
/// Previously only `:` (exchange prefix separator) was replaced; path
/// separators are now mapped to `_` as well, because a symbol such as
/// "BTC/USDT" would otherwise introduce an extra directory level and
/// corrupt the partition layout (cf. `sanitize_interval`, which already
/// rejects all non-alphanumerics).
fn sanitize_symbol(symbol: &str) -> String {
    symbol
        .chars()
        .map(|c| match c {
            ':' | '/' | '\\' => '_',
            other => other,
        })
        .collect()
}
427
/// Replaces every non-ASCII-alphanumeric character in `interval` with `_`
/// so the label is safe inside a partition directory name.
fn sanitize_interval(interval: &str) -> String {
    let mut cleaned = String::with_capacity(interval.len());
    for c in interval.chars() {
        cleaned.push(if c.is_ascii_alphanumeric() { c } else { '_' });
    }
    cleaned
}
434
435fn datetime_from_ns(timestamp: i64) -> Result<DateTime<Utc>> {
436    let seconds = timestamp / 1_000_000_000;
437    let nanos = (timestamp % 1_000_000_000).unsigned_abs() as u32;
438    DateTime::from_timestamp(seconds, nanos)
439        .ok_or_else(|| anyhow!("failed to convert timestamp {} to datetime", timestamp))
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    // End-to-end: a unix-millisecond CSV is normalized and at least one
    // parquet file lands under the Hive-style partition layout.
    #[test]
    fn pipeline_normalizes_csv() {
        let dir = tempdir().unwrap();
        let src = dir.path().join("candles.csv");
        // Two 1-minute candles with millisecond epochs (Nov 2023).
        fs::write(
            &src,
            "ts,open,high,low,close,vol
1700000000000,100,110,90,105,12
1700000060000,105,115,95,100,15
",
        )
        .unwrap();
        let mapping = MappingConfig {
            csv: CsvConfig::default(),
            fields: FieldMapping {
                timestamp: TimestampField {
                    col: 0,
                    unit: TimestampUnit::Milliseconds,
                    format: TimestampFormat::Unix,
                },
                open: ValueField { col: 1 },
                high: ValueField { col: 2 },
                low: ValueField { col: 3 },
                close: ValueField { col: 4 },
                volume: Some(ValueField { col: 5 }),
            },
            interval: "1m".into(),
        };
        let pipeline = Pipeline::new(mapping);
        let output = dir.path().join("lake");
        let rows = pipeline
            .run(
                src.to_str().unwrap(),
                &output,
                "binance:BTCUSDT",
                Partitioning::Daily,
            )
            .unwrap();
        assert_eq!(rows, 2);
        assert!(count_files(&output) > 0, "no parquet files written");
    }

    // Same pipeline but with RFC 3339 timestamps; the `unit` field should
    // be irrelevant for this format.
    #[test]
    fn pipeline_parses_rfc3339_timestamps() {
        let dir = tempdir().unwrap();
        let src = dir.path().join("candles.csv");
        fs::write(
            &src,
            "ts,open,high,low,close,vol\n2024-01-01T00:00:00Z,100,110,90,105,12\n",
        )
        .unwrap();
        let mapping = MappingConfig {
            csv: CsvConfig::default(),
            fields: FieldMapping {
                timestamp: TimestampField {
                    col: 0,
                    unit: TimestampUnit::Milliseconds,
                    format: TimestampFormat::Rfc3339,
                },
                open: ValueField { col: 1 },
                high: ValueField { col: 2 },
                low: ValueField { col: 3 },
                close: ValueField { col: 4 },
                volume: Some(ValueField { col: 5 }),
            },
            interval: "1m".into(),
        };
        let pipeline = Pipeline::new(mapping);
        let output = dir.path().join("lake");
        let rows = pipeline
            .run(
                src.to_str().unwrap(),
                &output,
                "binance:BTCUSDT",
                Partitioning::Daily,
            )
            .unwrap();
        assert_eq!(rows, 1);
        assert!(count_files(&output) > 0);
    }

    // Recursively counts regular files under `root`; 0 when `root` is
    // missing. Used to verify parquet output exists without reading it.
    fn count_files(root: &Path) -> usize {
        fn visit(dir: &Path, total: &mut usize) {
            if let Ok(entries) = fs::read_dir(dir) {
                for entry in entries.flatten() {
                    let path = entry.path();
                    if path.is_dir() {
                        visit(&path, total);
                    } else {
                        *total += 1;
                    }
                }
            }
        }
        let mut total = 0;
        if root.exists() {
            visit(root, &mut total);
        }
        total
    }
}