Skip to main content

data_preprocess/
db.rs

1use std::path::Path;
2
3use chrono::NaiveDateTime;
4use duckdb::{Connection, params};
5
6use crate::error::{DataError, Result};
7use crate::models::*;
8
// DuckDB doesn't implement ToSql/FromSql for chrono types,
// so we serialize timestamps as strings in ISO format.
// `TS_FMT` is the pattern used when writing timestamps and the first pattern
// tried when reading them back; it carries chrono's `%.f` fractional-seconds
// specifier.
const TS_FMT: &str = "%Y-%m-%d %H:%M:%S%.f";
// Fallback pattern (whole seconds only) tried when parsing with `TS_FMT` fails.
const TS_FMT_NO_FRAC: &str = "%Y-%m-%d %H:%M:%S";
13
14fn ndt_to_string(ndt: &NaiveDateTime) -> String {
15    ndt.format(TS_FMT).to_string()
16}
17
18fn string_to_ndt(s: &str) -> Result<NaiveDateTime> {
19    NaiveDateTime::parse_from_str(s, TS_FMT)
20        .or_else(|_| NaiveDateTime::parse_from_str(s, TS_FMT_NO_FRAC))
21        .map_err(|e| DataError::InvalidTimestamp(format!("{s}: {e}")))
22}
23
/// DuckDB-backed storage for ticks and bars.
pub struct Database {
    /// Underlying DuckDB connection; every statement in this module runs through it.
    conn: Connection,
    /// Display form of the path given to `open`; `None` for in-memory databases.
    /// Only read by `file_size`.
    db_path: Option<String>,
}
29
30/// Query parameters for tick view commands.
31pub struct QueryOpts {
32    pub exchange: String,
33    pub symbol: String,
34    pub from: Option<NaiveDateTime>,
35    pub to: Option<NaiveDateTime>,
36    pub limit: usize,
37    pub tail: bool,
38    pub descending: bool,
39}
40
41/// Query parameters for bar view commands.
42pub struct BarQueryOpts {
43    pub exchange: String,
44    pub symbol: String,
45    pub timeframe: String,
46    pub from: Option<NaiveDateTime>,
47    pub to: Option<NaiveDateTime>,
48    pub limit: usize,
49    pub tail: bool,
50    pub descending: bool,
51}
52
53impl Database {
54    /// Open (or create) a DuckDB database at the given path.
55    pub fn open(path: &Path) -> Result<Self> {
56        let conn = Connection::open(path)?;
57        let db = Self {
58            conn,
59            db_path: Some(path.display().to_string()),
60        };
61        db.init_schema()?;
62        Ok(db)
63    }
64
65    /// Open an in-memory database (for testing).
66    pub fn open_in_memory() -> Result<Self> {
67        let conn = Connection::open_in_memory()?;
68        let db = Self {
69            conn,
70            db_path: None,
71        };
72        db.init_schema()?;
73        Ok(db)
74    }
75
76    /// Create tables if they don't exist.
77    fn init_schema(&self) -> Result<()> {
78        self.conn.execute_batch(
79            "
80            CREATE TABLE IF NOT EXISTS ticks (
81                exchange    VARCHAR NOT NULL,
82                symbol      VARCHAR NOT NULL,
83                ts          VARCHAR NOT NULL,
84                bid         DOUBLE,
85                ask         DOUBLE,
86                last        DOUBLE,
87                volume      DOUBLE,
88                flags       INTEGER,
89                UNIQUE (exchange, symbol, ts)
90            );
91
92            CREATE TABLE IF NOT EXISTS bars (
93                exchange    VARCHAR NOT NULL,
94                symbol      VARCHAR NOT NULL,
95                timeframe   VARCHAR NOT NULL,
96                ts          VARCHAR NOT NULL,
97                open        DOUBLE NOT NULL,
98                high        DOUBLE NOT NULL,
99                low         DOUBLE NOT NULL,
100                close       DOUBLE NOT NULL,
101                tick_vol    BIGINT DEFAULT 0,
102                volume      BIGINT DEFAULT 0,
103                spread      INTEGER DEFAULT 0,
104                UNIQUE (exchange, symbol, timeframe, ts)
105            );
106            ",
107        )?;
108        Ok(())
109    }
110
111    // ── Insert ──
112
113    /// Bulk insert ticks using INSERT OR IGNORE for dedup.
114    /// Returns the number of rows actually inserted.
115    pub fn insert_ticks(&self, ticks: &[Tick]) -> Result<usize> {
116        if ticks.is_empty() {
117            return Ok(0);
118        }
119        let exchange = &ticks[0].exchange;
120        let symbol = &ticks[0].symbol;
121        let count_before = self.count_ticks(exchange, symbol)?;
122
123        self.conn.execute_batch("BEGIN TRANSACTION")?;
124        let mut stmt = self.conn.prepare(
125            "INSERT OR IGNORE INTO ticks (exchange, symbol, ts, bid, ask, last, volume, flags)
126             VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
127        )?;
128        for tick in ticks {
129            let ts_str = ndt_to_string(&tick.ts);
130            stmt.execute(params![
131                tick.exchange,
132                tick.symbol,
133                ts_str,
134                tick.bid,
135                tick.ask,
136                tick.last,
137                tick.volume,
138                tick.flags,
139            ])?;
140        }
141        drop(stmt);
142        self.conn.execute_batch("COMMIT")?;
143
144        let count_after = self.count_ticks(exchange, symbol)?;
145        Ok((count_after - count_before) as usize)
146    }
147
148    /// Bulk insert bars using INSERT OR IGNORE for dedup.
149    /// Returns the number of rows actually inserted.
150    pub fn insert_bars(&self, bars: &[Bar]) -> Result<usize> {
151        if bars.is_empty() {
152            return Ok(0);
153        }
154        let exchange = &bars[0].exchange;
155        let symbol = &bars[0].symbol;
156        let timeframe = bars[0].timeframe.as_str();
157        let count_before = self.count_bars(exchange, symbol, timeframe)?;
158
159        self.conn.execute_batch("BEGIN TRANSACTION")?;
160        let mut stmt = self.conn.prepare(
161            "INSERT OR IGNORE INTO bars
162             (exchange, symbol, timeframe, ts, open, high, low, close, tick_vol, volume, spread)
163             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
164        )?;
165        for bar in bars {
166            let ts_str = ndt_to_string(&bar.ts);
167            stmt.execute(params![
168                bar.exchange,
169                bar.symbol,
170                bar.timeframe.as_str(),
171                ts_str,
172                bar.open,
173                bar.high,
174                bar.low,
175                bar.close,
176                bar.tick_vol,
177                bar.volume,
178                bar.spread,
179            ])?;
180        }
181        drop(stmt);
182        self.conn.execute_batch("COMMIT")?;
183
184        let count_after = self.count_bars(exchange, symbol, timeframe)?;
185        Ok((count_after - count_before) as usize)
186    }
187
188    // ── Delete ──
189
190    /// Delete ticks matching exchange+symbol, optionally within a date range.
191    pub fn delete_ticks(
192        &self,
193        exchange: &str,
194        symbol: &str,
195        from: Option<NaiveDateTime>,
196        to: Option<NaiveDateTime>,
197    ) -> Result<usize> {
198        let mut sql = "DELETE FROM ticks WHERE exchange = ? AND symbol = ?".to_string();
199        let mut p: Vec<Box<dyn duckdb::types::ToSql>> =
200            vec![Box::new(exchange.to_string()), Box::new(symbol.to_string())];
201        if let Some(f) = from {
202            sql.push_str(" AND ts >= ?");
203            p.push(Box::new(ndt_to_string(&f)));
204        }
205        if let Some(t) = to {
206            sql.push_str(" AND ts <= ?");
207            p.push(Box::new(ndt_to_string(&t)));
208        }
209        let refs: Vec<&dyn duckdb::types::ToSql> = p.iter().map(|b| b.as_ref()).collect();
210        Ok(self.conn.execute(&sql, refs.as_slice())?)
211    }
212
213    /// Delete bars matching exchange+symbol+timeframe, optionally within a date range.
214    pub fn delete_bars(
215        &self,
216        exchange: &str,
217        symbol: &str,
218        timeframe: &str,
219        from: Option<NaiveDateTime>,
220        to: Option<NaiveDateTime>,
221    ) -> Result<usize> {
222        let mut sql =
223            "DELETE FROM bars WHERE exchange = ? AND symbol = ? AND timeframe = ?".to_string();
224        let mut p: Vec<Box<dyn duckdb::types::ToSql>> = vec![
225            Box::new(exchange.to_string()),
226            Box::new(symbol.to_string()),
227            Box::new(timeframe.to_string()),
228        ];
229        if let Some(f) = from {
230            sql.push_str(" AND ts >= ?");
231            p.push(Box::new(ndt_to_string(&f)));
232        }
233        if let Some(t) = to {
234            sql.push_str(" AND ts <= ?");
235            p.push(Box::new(ndt_to_string(&t)));
236        }
237        let refs: Vec<&dyn duckdb::types::ToSql> = p.iter().map(|b| b.as_ref()).collect();
238        Ok(self.conn.execute(&sql, refs.as_slice())?)
239    }
240
241    /// Delete ALL data (ticks + bars) for an exchange+symbol pair.
242    pub fn delete_symbol(&self, exchange: &str, symbol: &str) -> Result<(usize, usize)> {
243        let t = self.conn.execute(
244            "DELETE FROM ticks WHERE exchange = ? AND symbol = ?",
245            params![exchange, symbol],
246        )?;
247        let b = self.conn.execute(
248            "DELETE FROM bars WHERE exchange = ? AND symbol = ?",
249            params![exchange, symbol],
250        )?;
251        Ok((t, b))
252    }
253
254    /// Delete ALL data for an entire exchange.
255    pub fn delete_exchange(&self, exchange: &str) -> Result<(usize, usize)> {
256        let t = self
257            .conn
258            .execute("DELETE FROM ticks WHERE exchange = ?", params![exchange])?;
259        let b = self
260            .conn
261            .execute("DELETE FROM bars WHERE exchange = ?", params![exchange])?;
262        Ok((t, b))
263    }
264
265    // ── Query ──
266
267    /// Get summary statistics, optionally filtered by exchange and/or symbol.
268    pub fn stats(&self, exchange: Option<&str>, symbol: Option<&str>) -> Result<Vec<StatRow>> {
269        let where_clause = match (exchange, symbol) {
270            (Some(_), Some(_)) => "WHERE exchange = ? AND symbol = ?",
271            (Some(_), None) => "WHERE exchange = ?",
272            (None, Some(_)) => "WHERE symbol = ?",
273            (None, None) => "",
274        };
275
276        let sql = format!(
277            "SELECT exchange, symbol, 'tick' as data_type, COUNT(*) as count,
278                    MIN(ts) as ts_min, MAX(ts) as ts_max
279             FROM ticks {where_clause}
280             GROUP BY exchange, symbol
281             UNION ALL
282             SELECT exchange, symbol, 'bar (' || timeframe || ')' as data_type, COUNT(*) as count,
283                    MIN(ts) as ts_min, MAX(ts) as ts_max
284             FROM bars {where_clause}
285             GROUP BY exchange, symbol, timeframe
286             ORDER BY exchange, symbol, data_type"
287        );
288
289        let map_row = |row: &duckdb::Row| -> std::result::Result<StatRow, duckdb::Error> {
290            let ts_min_str: String = row.get(4)?;
291            let ts_max_str: String = row.get(5)?;
292            Ok(StatRow {
293                exchange: row.get(0)?,
294                symbol: row.get(1)?,
295                data_type: row.get(2)?,
296                count: row.get::<_, i64>(3)? as u64,
297                ts_min: string_to_ndt(&ts_min_str).unwrap_or_default(),
298                ts_max: string_to_ndt(&ts_max_str).unwrap_or_default(),
299            })
300        };
301
302        let mut stmt = self.conn.prepare(&sql)?;
303
304        // Bind params for both halves of the UNION ALL
305        let rows: Vec<StatRow> = match (exchange, symbol) {
306            (Some(ex), Some(sym)) => stmt
307                .query_map(params![ex, sym, ex, sym], map_row)?
308                .filter_map(|r| r.ok())
309                .collect(),
310            (Some(ex), None) => stmt
311                .query_map(params![ex, ex], map_row)?
312                .filter_map(|r| r.ok())
313                .collect(),
314            (None, Some(sym)) => stmt
315                .query_map(params![sym, sym], map_row)?
316                .filter_map(|r| r.ok())
317                .collect(),
318            (None, None) => stmt
319                .query_map([], map_row)?
320                .filter_map(|r| r.ok())
321                .collect(),
322        };
323        Ok(rows)
324    }
325
326    /// Query ticks with filtering and pagination.
327    pub fn query_ticks(&self, opts: &QueryOpts) -> Result<(Vec<Tick>, u64)> {
328        let total = self.count_filtered(
329            "ticks",
330            &opts.exchange,
331            &opts.symbol,
332            None,
333            opts.from,
334            opts.to,
335        )?;
336
337        let order = if opts.descending { "DESC" } else { "ASC" };
338        let (mut where_parts, mut bind_vals) = base_where(&opts.exchange, &opts.symbol);
339        append_ts_filters(&mut where_parts, &mut bind_vals, opts.from, opts.to);
340        let where_sql = where_parts.join(" AND ");
341
342        let sql = if opts.tail {
343            format!(
344                "SELECT * FROM (
345                    SELECT exchange, symbol, ts, bid, ask, last, volume, flags
346                    FROM ticks WHERE {where_sql} ORDER BY ts DESC LIMIT ?
347                 ) sub ORDER BY ts {order}"
348            )
349        } else {
350            format!(
351                "SELECT exchange, symbol, ts, bid, ask, last, volume, flags
352                 FROM ticks WHERE {where_sql} ORDER BY ts {order} LIMIT ?"
353            )
354        };
355        bind_vals.push(BVal::Int(opts.limit as i64));
356
357        let mut stmt = self.conn.prepare(&sql)?;
358        let ticks = exec_query(&mut stmt, &bind_vals, |row| {
359            let ts_str: String = row.get(2)?;
360            Ok(Tick {
361                exchange: row.get(0)?,
362                symbol: row.get(1)?,
363                ts: string_to_ndt(&ts_str).unwrap_or_default(),
364                bid: row.get(3)?,
365                ask: row.get(4)?,
366                last: row.get(5)?,
367                volume: row.get(6)?,
368                flags: row.get(7)?,
369            })
370        })?;
371        Ok((ticks, total))
372    }
373
374    /// Query bars with filtering and pagination.
375    pub fn query_bars(&self, opts: &BarQueryOpts) -> Result<(Vec<Bar>, u64)> {
376        let total = self.count_filtered(
377            "bars",
378            &opts.exchange,
379            &opts.symbol,
380            Some(&opts.timeframe),
381            opts.from,
382            opts.to,
383        )?;
384
385        let order = if opts.descending { "DESC" } else { "ASC" };
386        let (mut where_parts, mut bind_vals) = base_where(&opts.exchange, &opts.symbol);
387        where_parts.push("timeframe = ?".to_string());
388        bind_vals.push(BVal::Str(opts.timeframe.clone()));
389        append_ts_filters(&mut where_parts, &mut bind_vals, opts.from, opts.to);
390        let where_sql = where_parts.join(" AND ");
391
392        let sql = if opts.tail {
393            format!(
394                "SELECT * FROM (
395                    SELECT exchange, symbol, timeframe, ts, open, high, low, close,
396                           tick_vol, volume, spread
397                    FROM bars WHERE {where_sql} ORDER BY ts DESC LIMIT ?
398                 ) sub ORDER BY ts {order}"
399            )
400        } else {
401            format!(
402                "SELECT exchange, symbol, timeframe, ts, open, high, low, close,
403                        tick_vol, volume, spread
404                 FROM bars WHERE {where_sql} ORDER BY ts {order} LIMIT ?"
405            )
406        };
407        bind_vals.push(BVal::Int(opts.limit as i64));
408
409        let mut stmt = self.conn.prepare(&sql)?;
410        let bars = exec_query(&mut stmt, &bind_vals, |row| {
411            let tf_str: String = row.get(2)?;
412            let ts_str: String = row.get(3)?;
413            Ok(Bar {
414                exchange: row.get(0)?,
415                symbol: row.get(1)?,
416                timeframe: Timeframe::parse(&tf_str).unwrap_or(Timeframe::M1),
417                ts: string_to_ndt(&ts_str).unwrap_or_default(),
418                open: row.get(4)?,
419                high: row.get(5)?,
420                low: row.get(6)?,
421                close: row.get(7)?,
422                tick_vol: row.get(8)?,
423                volume: row.get(9)?,
424                spread: row.get(10)?,
425            })
426        })?;
427        Ok((bars, total))
428    }
429
430    /// Get database file size in bytes (None for in-memory).
431    pub fn file_size(&self) -> Option<u64> {
432        self.db_path
433            .as_ref()
434            .and_then(|p| std::fs::metadata(p).ok())
435            .map(|m| m.len())
436    }
437
438    // ── Private helpers ──
439
440    fn count_ticks(&self, exchange: &str, symbol: &str) -> Result<u64> {
441        let c: i64 = self.conn.query_row(
442            "SELECT COUNT(*) FROM ticks WHERE exchange = ? AND symbol = ?",
443            params![exchange, symbol],
444            |row| row.get(0),
445        )?;
446        Ok(c as u64)
447    }
448
449    fn count_bars(&self, exchange: &str, symbol: &str, timeframe: &str) -> Result<u64> {
450        let c: i64 = self.conn.query_row(
451            "SELECT COUNT(*) FROM bars WHERE exchange = ? AND symbol = ? AND timeframe = ?",
452            params![exchange, symbol, timeframe],
453            |row| row.get(0),
454        )?;
455        Ok(c as u64)
456    }
457
458    fn count_filtered(
459        &self,
460        table: &str,
461        exchange: &str,
462        symbol: &str,
463        timeframe: Option<&str>,
464        from: Option<NaiveDateTime>,
465        to: Option<NaiveDateTime>,
466    ) -> Result<u64> {
467        let (mut parts, mut vals) = base_where(exchange, symbol);
468        if let Some(tf) = timeframe {
469            parts.push("timeframe = ?".to_string());
470            vals.push(BVal::Str(tf.to_string()));
471        }
472        append_ts_filters(&mut parts, &mut vals, from, to);
473        let sql = format!(
474            "SELECT COUNT(*) FROM {} WHERE {}",
475            table,
476            parts.join(" AND ")
477        );
478        count_with_binds(&self.conn, &sql, &vals)
479    }
480}
481
// ── Bind-value helpers ──
// We use a small enum so we can build dynamic param lists at runtime.

/// One dynamically typed value to bind to a `?` placeholder.
enum BVal {
    /// Text value (exchange, symbol, timeframe, serialized timestamps).
    Str(String),
    /// Integer value (used in this file for the `LIMIT` placeholder).
    Int(i64),
}
489
490fn base_where(exchange: &str, symbol: &str) -> (Vec<String>, Vec<BVal>) {
491    (
492        vec!["exchange = ?".to_string(), "symbol = ?".to_string()],
493        vec![
494            BVal::Str(exchange.to_string()),
495            BVal::Str(symbol.to_string()),
496        ],
497    )
498}
499
500fn append_ts_filters(
501    parts: &mut Vec<String>,
502    vals: &mut Vec<BVal>,
503    from: Option<NaiveDateTime>,
504    to: Option<NaiveDateTime>,
505) {
506    if let Some(f) = from {
507        parts.push("ts >= ?".to_string());
508        vals.push(BVal::Str(ndt_to_string(&f)));
509    }
510    if let Some(t) = to {
511        parts.push("ts <= ?".to_string());
512        vals.push(BVal::Str(ndt_to_string(&t)));
513    }
514}
515
516/// Convert BVal slice into boxed ToSql trait objects, then ref-slice for duckdb.
517fn to_dyn_params(binds: &[BVal]) -> Vec<Box<dyn duckdb::types::ToSql>> {
518    binds
519        .iter()
520        .map(|b| -> Box<dyn duckdb::types::ToSql> {
521            match b {
522                BVal::Str(s) => Box::new(s.clone()),
523                BVal::Int(n) => Box::new(*n),
524            }
525        })
526        .collect()
527}
528
529/// Execute a SELECT with dynamic binds and map each row.
530fn exec_query<T, F>(stmt: &mut duckdb::Statement, binds: &[BVal], map_fn: F) -> Result<Vec<T>>
531where
532    F: Fn(&duckdb::Row) -> std::result::Result<T, duckdb::Error>,
533{
534    let params = to_dyn_params(binds);
535    let refs: Vec<&dyn duckdb::types::ToSql> = params.iter().map(|b| b.as_ref()).collect();
536    let rows = stmt.query_map(refs.as_slice(), &map_fn)?;
537    Ok(rows.filter_map(|r| r.ok()).collect())
538}
539
540/// Execute a COUNT(*) query with dynamic binds.
541fn count_with_binds(conn: &Connection, sql: &str, binds: &[BVal]) -> Result<u64> {
542    let params = to_dyn_params(binds);
543    let refs: Vec<&dyn duckdb::types::ToSql> = params.iter().map(|b| b.as_ref()).collect();
544    let c: i64 = conn.query_row(sql, refs.as_slice(), |row| row.get(0))?;
545    Ok(c as u64)
546}