1use std::path::Path;
2
3use chrono::NaiveDateTime;
4use duckdb::{Connection, params};
5
6use crate::error::{DataError, Result};
7use crate::models::*;
8
/// Format used to render timestamps into the `ts` text column; it is also the
/// first format tried when parsing them back. Includes a `%.f` fractional part.
const TS_FMT: &str = "%Y-%m-%d %H:%M:%S%.f";
/// Fallback parse format, tried when a stored timestamp does not match `TS_FMT`.
const TS_FMT_NO_FRAC: &str = "%Y-%m-%d %H:%M:%S";
13
14fn ndt_to_string(ndt: &NaiveDateTime) -> String {
15 ndt.format(TS_FMT).to_string()
16}
17
18fn string_to_ndt(s: &str) -> Result<NaiveDateTime> {
19 NaiveDateTime::parse_from_str(s, TS_FMT)
20 .or_else(|_| NaiveDateTime::parse_from_str(s, TS_FMT_NO_FRAC))
21 .map_err(|e| DataError::InvalidTimestamp(format!("{s}: {e}")))
22}
23
/// Handle to a DuckDB database holding the `ticks` and `bars` tables.
pub struct Database {
    /// Underlying DuckDB connection used for every statement.
    conn: Connection,
    /// Display form of the path the database was opened from; `None` for an
    /// in-memory database. Used to report the on-disk file size.
    db_path: Option<String>,
}
29
30pub struct QueryOpts {
32 pub exchange: String,
33 pub symbol: String,
34 pub from: Option<NaiveDateTime>,
35 pub to: Option<NaiveDateTime>,
36 pub limit: usize,
37 pub tail: bool,
38 pub descending: bool,
39}
40
41pub struct BarQueryOpts {
43 pub exchange: String,
44 pub symbol: String,
45 pub timeframe: String,
46 pub from: Option<NaiveDateTime>,
47 pub to: Option<NaiveDateTime>,
48 pub limit: usize,
49 pub tail: bool,
50 pub descending: bool,
51}
52
53impl Database {
54 pub fn open(path: &Path) -> Result<Self> {
56 let conn = Connection::open(path)?;
57 let db = Self {
58 conn,
59 db_path: Some(path.display().to_string()),
60 };
61 db.init_schema()?;
62 Ok(db)
63 }
64
65 pub fn open_in_memory() -> Result<Self> {
67 let conn = Connection::open_in_memory()?;
68 let db = Self {
69 conn,
70 db_path: None,
71 };
72 db.init_schema()?;
73 Ok(db)
74 }
75
76 fn init_schema(&self) -> Result<()> {
78 self.conn.execute_batch(
79 "
80 CREATE TABLE IF NOT EXISTS ticks (
81 exchange VARCHAR NOT NULL,
82 symbol VARCHAR NOT NULL,
83 ts VARCHAR NOT NULL,
84 bid DOUBLE,
85 ask DOUBLE,
86 last DOUBLE,
87 volume DOUBLE,
88 flags INTEGER,
89 UNIQUE (exchange, symbol, ts)
90 );
91
92 CREATE TABLE IF NOT EXISTS bars (
93 exchange VARCHAR NOT NULL,
94 symbol VARCHAR NOT NULL,
95 timeframe VARCHAR NOT NULL,
96 ts VARCHAR NOT NULL,
97 open DOUBLE NOT NULL,
98 high DOUBLE NOT NULL,
99 low DOUBLE NOT NULL,
100 close DOUBLE NOT NULL,
101 tick_vol BIGINT DEFAULT 0,
102 volume BIGINT DEFAULT 0,
103 spread INTEGER DEFAULT 0,
104 UNIQUE (exchange, symbol, timeframe, ts)
105 );
106 ",
107 )?;
108 Ok(())
109 }
110
111 pub fn insert_ticks(&self, ticks: &[Tick]) -> Result<usize> {
116 if ticks.is_empty() {
117 return Ok(0);
118 }
119 let exchange = &ticks[0].exchange;
120 let symbol = &ticks[0].symbol;
121 let count_before = self.count_ticks(exchange, symbol)?;
122
123 self.conn.execute_batch("BEGIN TRANSACTION")?;
124 let mut stmt = self.conn.prepare(
125 "INSERT OR IGNORE INTO ticks (exchange, symbol, ts, bid, ask, last, volume, flags)
126 VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
127 )?;
128 for tick in ticks {
129 let ts_str = ndt_to_string(&tick.ts);
130 stmt.execute(params![
131 tick.exchange,
132 tick.symbol,
133 ts_str,
134 tick.bid,
135 tick.ask,
136 tick.last,
137 tick.volume,
138 tick.flags,
139 ])?;
140 }
141 drop(stmt);
142 self.conn.execute_batch("COMMIT")?;
143
144 let count_after = self.count_ticks(exchange, symbol)?;
145 Ok((count_after - count_before) as usize)
146 }
147
148 pub fn insert_bars(&self, bars: &[Bar]) -> Result<usize> {
151 if bars.is_empty() {
152 return Ok(0);
153 }
154 let exchange = &bars[0].exchange;
155 let symbol = &bars[0].symbol;
156 let timeframe = bars[0].timeframe.as_str();
157 let count_before = self.count_bars(exchange, symbol, timeframe)?;
158
159 self.conn.execute_batch("BEGIN TRANSACTION")?;
160 let mut stmt = self.conn.prepare(
161 "INSERT OR IGNORE INTO bars
162 (exchange, symbol, timeframe, ts, open, high, low, close, tick_vol, volume, spread)
163 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
164 )?;
165 for bar in bars {
166 let ts_str = ndt_to_string(&bar.ts);
167 stmt.execute(params![
168 bar.exchange,
169 bar.symbol,
170 bar.timeframe.as_str(),
171 ts_str,
172 bar.open,
173 bar.high,
174 bar.low,
175 bar.close,
176 bar.tick_vol,
177 bar.volume,
178 bar.spread,
179 ])?;
180 }
181 drop(stmt);
182 self.conn.execute_batch("COMMIT")?;
183
184 let count_after = self.count_bars(exchange, symbol, timeframe)?;
185 Ok((count_after - count_before) as usize)
186 }
187
188 pub fn delete_ticks(
192 &self,
193 exchange: &str,
194 symbol: &str,
195 from: Option<NaiveDateTime>,
196 to: Option<NaiveDateTime>,
197 ) -> Result<usize> {
198 let mut sql = "DELETE FROM ticks WHERE exchange = ? AND symbol = ?".to_string();
199 let mut p: Vec<Box<dyn duckdb::types::ToSql>> =
200 vec![Box::new(exchange.to_string()), Box::new(symbol.to_string())];
201 if let Some(f) = from {
202 sql.push_str(" AND ts >= ?");
203 p.push(Box::new(ndt_to_string(&f)));
204 }
205 if let Some(t) = to {
206 sql.push_str(" AND ts <= ?");
207 p.push(Box::new(ndt_to_string(&t)));
208 }
209 let refs: Vec<&dyn duckdb::types::ToSql> = p.iter().map(|b| b.as_ref()).collect();
210 Ok(self.conn.execute(&sql, refs.as_slice())?)
211 }
212
213 pub fn delete_bars(
215 &self,
216 exchange: &str,
217 symbol: &str,
218 timeframe: &str,
219 from: Option<NaiveDateTime>,
220 to: Option<NaiveDateTime>,
221 ) -> Result<usize> {
222 let mut sql =
223 "DELETE FROM bars WHERE exchange = ? AND symbol = ? AND timeframe = ?".to_string();
224 let mut p: Vec<Box<dyn duckdb::types::ToSql>> = vec![
225 Box::new(exchange.to_string()),
226 Box::new(symbol.to_string()),
227 Box::new(timeframe.to_string()),
228 ];
229 if let Some(f) = from {
230 sql.push_str(" AND ts >= ?");
231 p.push(Box::new(ndt_to_string(&f)));
232 }
233 if let Some(t) = to {
234 sql.push_str(" AND ts <= ?");
235 p.push(Box::new(ndt_to_string(&t)));
236 }
237 let refs: Vec<&dyn duckdb::types::ToSql> = p.iter().map(|b| b.as_ref()).collect();
238 Ok(self.conn.execute(&sql, refs.as_slice())?)
239 }
240
241 pub fn delete_symbol(&self, exchange: &str, symbol: &str) -> Result<(usize, usize)> {
243 let t = self.conn.execute(
244 "DELETE FROM ticks WHERE exchange = ? AND symbol = ?",
245 params![exchange, symbol],
246 )?;
247 let b = self.conn.execute(
248 "DELETE FROM bars WHERE exchange = ? AND symbol = ?",
249 params![exchange, symbol],
250 )?;
251 Ok((t, b))
252 }
253
254 pub fn delete_exchange(&self, exchange: &str) -> Result<(usize, usize)> {
256 let t = self
257 .conn
258 .execute("DELETE FROM ticks WHERE exchange = ?", params![exchange])?;
259 let b = self
260 .conn
261 .execute("DELETE FROM bars WHERE exchange = ?", params![exchange])?;
262 Ok((t, b))
263 }
264
265 pub fn stats(&self, exchange: Option<&str>, symbol: Option<&str>) -> Result<Vec<StatRow>> {
269 let where_clause = match (exchange, symbol) {
270 (Some(_), Some(_)) => "WHERE exchange = ? AND symbol = ?",
271 (Some(_), None) => "WHERE exchange = ?",
272 (None, Some(_)) => "WHERE symbol = ?",
273 (None, None) => "",
274 };
275
276 let sql = format!(
277 "SELECT exchange, symbol, 'tick' as data_type, COUNT(*) as count,
278 MIN(ts) as ts_min, MAX(ts) as ts_max
279 FROM ticks {where_clause}
280 GROUP BY exchange, symbol
281 UNION ALL
282 SELECT exchange, symbol, 'bar (' || timeframe || ')' as data_type, COUNT(*) as count,
283 MIN(ts) as ts_min, MAX(ts) as ts_max
284 FROM bars {where_clause}
285 GROUP BY exchange, symbol, timeframe
286 ORDER BY exchange, symbol, data_type"
287 );
288
289 let map_row = |row: &duckdb::Row| -> std::result::Result<StatRow, duckdb::Error> {
290 let ts_min_str: String = row.get(4)?;
291 let ts_max_str: String = row.get(5)?;
292 Ok(StatRow {
293 exchange: row.get(0)?,
294 symbol: row.get(1)?,
295 data_type: row.get(2)?,
296 count: row.get::<_, i64>(3)? as u64,
297 ts_min: string_to_ndt(&ts_min_str).unwrap_or_default(),
298 ts_max: string_to_ndt(&ts_max_str).unwrap_or_default(),
299 })
300 };
301
302 let mut stmt = self.conn.prepare(&sql)?;
303
304 let rows: Vec<StatRow> = match (exchange, symbol) {
306 (Some(ex), Some(sym)) => stmt
307 .query_map(params![ex, sym, ex, sym], map_row)?
308 .filter_map(|r| r.ok())
309 .collect(),
310 (Some(ex), None) => stmt
311 .query_map(params![ex, ex], map_row)?
312 .filter_map(|r| r.ok())
313 .collect(),
314 (None, Some(sym)) => stmt
315 .query_map(params![sym, sym], map_row)?
316 .filter_map(|r| r.ok())
317 .collect(),
318 (None, None) => stmt
319 .query_map([], map_row)?
320 .filter_map(|r| r.ok())
321 .collect(),
322 };
323 Ok(rows)
324 }
325
326 pub fn query_ticks(&self, opts: &QueryOpts) -> Result<(Vec<Tick>, u64)> {
328 let total = self.count_filtered(
329 "ticks",
330 &opts.exchange,
331 &opts.symbol,
332 None,
333 opts.from,
334 opts.to,
335 )?;
336
337 let order = if opts.descending { "DESC" } else { "ASC" };
338 let (mut where_parts, mut bind_vals) = base_where(&opts.exchange, &opts.symbol);
339 append_ts_filters(&mut where_parts, &mut bind_vals, opts.from, opts.to);
340 let where_sql = where_parts.join(" AND ");
341
342 let sql = if opts.tail {
343 format!(
344 "SELECT * FROM (
345 SELECT exchange, symbol, ts, bid, ask, last, volume, flags
346 FROM ticks WHERE {where_sql} ORDER BY ts DESC LIMIT ?
347 ) sub ORDER BY ts {order}"
348 )
349 } else {
350 format!(
351 "SELECT exchange, symbol, ts, bid, ask, last, volume, flags
352 FROM ticks WHERE {where_sql} ORDER BY ts {order} LIMIT ?"
353 )
354 };
355 bind_vals.push(BVal::Int(opts.limit as i64));
356
357 let mut stmt = self.conn.prepare(&sql)?;
358 let ticks = exec_query(&mut stmt, &bind_vals, |row| {
359 let ts_str: String = row.get(2)?;
360 Ok(Tick {
361 exchange: row.get(0)?,
362 symbol: row.get(1)?,
363 ts: string_to_ndt(&ts_str).unwrap_or_default(),
364 bid: row.get(3)?,
365 ask: row.get(4)?,
366 last: row.get(5)?,
367 volume: row.get(6)?,
368 flags: row.get(7)?,
369 })
370 })?;
371 Ok((ticks, total))
372 }
373
374 pub fn query_bars(&self, opts: &BarQueryOpts) -> Result<(Vec<Bar>, u64)> {
376 let total = self.count_filtered(
377 "bars",
378 &opts.exchange,
379 &opts.symbol,
380 Some(&opts.timeframe),
381 opts.from,
382 opts.to,
383 )?;
384
385 let order = if opts.descending { "DESC" } else { "ASC" };
386 let (mut where_parts, mut bind_vals) = base_where(&opts.exchange, &opts.symbol);
387 where_parts.push("timeframe = ?".to_string());
388 bind_vals.push(BVal::Str(opts.timeframe.clone()));
389 append_ts_filters(&mut where_parts, &mut bind_vals, opts.from, opts.to);
390 let where_sql = where_parts.join(" AND ");
391
392 let sql = if opts.tail {
393 format!(
394 "SELECT * FROM (
395 SELECT exchange, symbol, timeframe, ts, open, high, low, close,
396 tick_vol, volume, spread
397 FROM bars WHERE {where_sql} ORDER BY ts DESC LIMIT ?
398 ) sub ORDER BY ts {order}"
399 )
400 } else {
401 format!(
402 "SELECT exchange, symbol, timeframe, ts, open, high, low, close,
403 tick_vol, volume, spread
404 FROM bars WHERE {where_sql} ORDER BY ts {order} LIMIT ?"
405 )
406 };
407 bind_vals.push(BVal::Int(opts.limit as i64));
408
409 let mut stmt = self.conn.prepare(&sql)?;
410 let bars = exec_query(&mut stmt, &bind_vals, |row| {
411 let tf_str: String = row.get(2)?;
412 let ts_str: String = row.get(3)?;
413 Ok(Bar {
414 exchange: row.get(0)?,
415 symbol: row.get(1)?,
416 timeframe: Timeframe::parse(&tf_str).unwrap_or(Timeframe::M1),
417 ts: string_to_ndt(&ts_str).unwrap_or_default(),
418 open: row.get(4)?,
419 high: row.get(5)?,
420 low: row.get(6)?,
421 close: row.get(7)?,
422 tick_vol: row.get(8)?,
423 volume: row.get(9)?,
424 spread: row.get(10)?,
425 })
426 })?;
427 Ok((bars, total))
428 }
429
430 pub fn file_size(&self) -> Option<u64> {
432 self.db_path
433 .as_ref()
434 .and_then(|p| std::fs::metadata(p).ok())
435 .map(|m| m.len())
436 }
437
438 fn count_ticks(&self, exchange: &str, symbol: &str) -> Result<u64> {
441 let c: i64 = self.conn.query_row(
442 "SELECT COUNT(*) FROM ticks WHERE exchange = ? AND symbol = ?",
443 params![exchange, symbol],
444 |row| row.get(0),
445 )?;
446 Ok(c as u64)
447 }
448
449 fn count_bars(&self, exchange: &str, symbol: &str, timeframe: &str) -> Result<u64> {
450 let c: i64 = self.conn.query_row(
451 "SELECT COUNT(*) FROM bars WHERE exchange = ? AND symbol = ? AND timeframe = ?",
452 params![exchange, symbol, timeframe],
453 |row| row.get(0),
454 )?;
455 Ok(c as u64)
456 }
457
458 fn count_filtered(
459 &self,
460 table: &str,
461 exchange: &str,
462 symbol: &str,
463 timeframe: Option<&str>,
464 from: Option<NaiveDateTime>,
465 to: Option<NaiveDateTime>,
466 ) -> Result<u64> {
467 let (mut parts, mut vals) = base_where(exchange, symbol);
468 if let Some(tf) = timeframe {
469 parts.push("timeframe = ?".to_string());
470 vals.push(BVal::Str(tf.to_string()));
471 }
472 append_ts_filters(&mut parts, &mut vals, from, to);
473 let sql = format!(
474 "SELECT COUNT(*) FROM {} WHERE {}",
475 table,
476 parts.join(" AND ")
477 );
478 count_with_binds(&self.conn, &sql, &vals)
479 }
480}
481
/// A positional bind value for dynamically assembled SQL statements.
#[derive(Debug, Clone, PartialEq, Eq)]
enum BVal {
    /// Text parameter (identifiers such as exchange/symbol, or a formatted timestamp).
    Str(String),
    /// Integer parameter (for example a `LIMIT` value).
    Int(i64),
}
489
490fn base_where(exchange: &str, symbol: &str) -> (Vec<String>, Vec<BVal>) {
491 (
492 vec!["exchange = ?".to_string(), "symbol = ?".to_string()],
493 vec![
494 BVal::Str(exchange.to_string()),
495 BVal::Str(symbol.to_string()),
496 ],
497 )
498}
499
500fn append_ts_filters(
501 parts: &mut Vec<String>,
502 vals: &mut Vec<BVal>,
503 from: Option<NaiveDateTime>,
504 to: Option<NaiveDateTime>,
505) {
506 if let Some(f) = from {
507 parts.push("ts >= ?".to_string());
508 vals.push(BVal::Str(ndt_to_string(&f)));
509 }
510 if let Some(t) = to {
511 parts.push("ts <= ?".to_string());
512 vals.push(BVal::Str(ndt_to_string(&t)));
513 }
514}
515
516fn to_dyn_params(binds: &[BVal]) -> Vec<Box<dyn duckdb::types::ToSql>> {
518 binds
519 .iter()
520 .map(|b| -> Box<dyn duckdb::types::ToSql> {
521 match b {
522 BVal::Str(s) => Box::new(s.clone()),
523 BVal::Int(n) => Box::new(*n),
524 }
525 })
526 .collect()
527}
528
529fn exec_query<T, F>(stmt: &mut duckdb::Statement, binds: &[BVal], map_fn: F) -> Result<Vec<T>>
531where
532 F: Fn(&duckdb::Row) -> std::result::Result<T, duckdb::Error>,
533{
534 let params = to_dyn_params(binds);
535 let refs: Vec<&dyn duckdb::types::ToSql> = params.iter().map(|b| b.as_ref()).collect();
536 let rows = stmt.query_map(refs.as_slice(), &map_fn)?;
537 Ok(rows.filter_map(|r| r.ok()).collect())
538}
539
540fn count_with_binds(conn: &Connection, sql: &str, binds: &[BVal]) -> Result<u64> {
542 let params = to_dyn_params(binds);
543 let refs: Vec<&dyn duckdb::types::ToSql> = params.iter().map(|b| b.as_ref()).collect();
544 let c: i64 = conn.query_row(sql, refs.as_slice(), |row| row.get(0))?;
545 Ok(c as u64)
546}