Skip to main content

hatidata_cli/
local_engine.rs

1use std::path::Path;
2
3use anyhow::{Context, Result};
4use duckdb::types::Value;
5use duckdb::Connection;
6
/// Structured query result returned by `execute_query`.
///
/// `columns` lists the result column names in order; `rows` holds one entry
/// per result row, with every cell already rendered as a display string.
/// Both are empty for statements that produce no result set.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct QueryResult {
    /// Column names, in result order.
    pub columns: Vec<String>,
    /// Row data; each inner vector holds one row's cells rendered as strings.
    pub rows: Vec<Vec<String>>,
}
12
/// Information about a table in the local DuckDB database.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TableInfo {
    /// Unqualified table name.
    pub name: String,
    /// Schema that contains the table.
    pub schema: String,
}
18
19/// Local DuckDB engine wrapper for the HatiData CLI.
20pub struct LocalEngine {
21    conn: Connection,
22}
23
24impl LocalEngine {
25    /// Open (or create) a DuckDB database at the given path.
26    pub fn open(path: &Path) -> Result<Self> {
27        let conn =
28            Connection::open(path).with_context(|| format!("Failed to open DuckDB at {}", path.display()))?;
29        Ok(Self { conn })
30    }
31
32    /// Execute a SQL query and return structured results.
33    ///
34    /// IMPORTANT DuckDB 1.4.4 API note: `column_count()` and `column_name()`
35    /// panic if called before the statement is executed. We must execute first
36    /// (via `query`), then read column metadata.
37    ///
38    /// Uses `duckdb::types::Value` for reading cell values to handle all types
39    /// correctly (the DuckDB Rust API's `row.get::<_, String>(i)` fails for
40    /// non-String types).
41    pub fn execute_query(&self, sql: &str) -> Result<QueryResult> {
42        let trimmed = sql.trim().to_uppercase();
43        let is_select = trimmed.starts_with("SELECT")
44            || trimmed.starts_with("WITH")
45            || trimmed.starts_with("SHOW")
46            || trimmed.starts_with("DESCRIBE")
47            || trimmed.starts_with("EXPLAIN")
48            || trimmed.starts_with("PRAGMA");
49
50        if !is_select {
51            // DDL/DML: execute and return empty result
52            self.conn
53                .execute_batch(sql)
54                .with_context(|| format!("Failed to execute SQL: {sql}"))?;
55            return Ok(QueryResult {
56                columns: Vec::new(),
57                rows: Vec::new(),
58            });
59        }
60
61        // SELECT-like: use query_map to execute and collect rows in one pass.
62        // query_map internally executes the statement. We collect into a Vec
63        // which drops the mutable borrow on stmt, allowing us to then call
64        // column_count()/column_name() safely.
65        let mut stmt = self
66            .conn
67            .prepare(sql)
68            .with_context(|| format!("Failed to prepare SQL: {sql}"))?;
69
70        // We don't know column_count before execution (DuckDB 1.4.4 panics).
71        // Use a dynamic approach: read values until get() fails.
72        let raw_rows: Vec<Vec<(usize, Value)>> = stmt
73            .query_map([], |row| {
74                let mut values = Vec::new();
75                let mut i = 0;
76                loop {
77                    match row.get::<_, Value>(i) {
78                        Ok(val) => {
79                            values.push((i, val));
80                            i += 1;
81                        }
82                        Err(_) => break,
83                    }
84                }
85                Ok(values)
86            })
87            .with_context(|| format!("Failed to execute query: {sql}"))?
88            .collect::<std::result::Result<Vec<_>, _>>()
89            .context("Failed to read rows")?;
90
91        // Now stmt's mutable borrow is released; column_count()/column_name() are safe
92        let column_count = stmt.column_count();
93        let column_names: Vec<String> = (0..column_count)
94            .map(|i| stmt.column_name(i).map_or("?".to_string(), |v| v.to_string()))
95            .collect();
96
97        let rows: Vec<Vec<String>> = raw_rows
98            .into_iter()
99            .map(|vals| vals.into_iter().map(|(_, v)| value_to_string(&v)).collect())
100            .collect();
101
102        Ok(QueryResult {
103            columns: column_names,
104            rows,
105        })
106    }
107
108    /// List all user tables in the database.
109    pub fn list_tables(&self) -> Result<Vec<TableInfo>> {
110        let mut stmt = self
111            .conn
112            .prepare(
113                "SELECT table_schema, table_name FROM information_schema.tables \
114                 WHERE table_schema NOT IN ('information_schema', 'pg_catalog') \
115                 AND table_type = 'BASE TABLE' \
116                 ORDER BY table_schema, table_name",
117            )
118            .context("Failed to query information_schema")?;
119
120        let rows = stmt
121            .query_map([], |row| {
122                let schema: Value = row.get(0)?;
123                let name: Value = row.get(1)?;
124                Ok(TableInfo {
125                    schema: value_to_string(&schema),
126                    name: value_to_string(&name),
127                })
128            })
129            .context("Failed to list tables")?;
130
131        let mut tables = Vec::new();
132        for row in rows {
133            tables.push(row.context("Failed to read table info")?);
134        }
135
136        Ok(tables)
137    }
138
139    /// Get the row count for a specific table.
140    pub fn table_row_count(&self, table: &str) -> Result<u64> {
141        // Validate table name to prevent SQL injection (alphanumeric + underscore only)
142        if !table
143            .chars()
144            .all(|c| c.is_alphanumeric() || c == '_')
145        {
146            anyhow::bail!("Invalid table name: {table}");
147        }
148
149        let sql = format!("SELECT COUNT(*) FROM \"{table}\"");
150        let mut stmt = self.conn.prepare(&sql)?;
151        let mut rows = stmt.query([])?;
152
153        if let Some(row) = rows.next()? {
154            let value: Value = row.get(0)?;
155            match value {
156                Value::BigInt(n) => Ok(n as u64),
157                Value::Int(n) => Ok(n as u64),
158                Value::HugeInt(n) => Ok(n as u64),
159                _ => Ok(0),
160            }
161        } else {
162            Ok(0)
163        }
164    }
165
166    /// Export a table to a Parquet file.
167    pub fn export_table_parquet(&self, table: &str, output: &Path) -> Result<()> {
168        // Validate table name
169        if !table
170            .chars()
171            .all(|c| c.is_alphanumeric() || c == '_')
172        {
173            anyhow::bail!("Invalid table name: {table}");
174        }
175
176        let output_str = output.display().to_string();
177        let sql = format!("COPY \"{table}\" TO '{output_str}' (FORMAT PARQUET)");
178        self.conn
179            .execute_batch(&sql)
180            .with_context(|| format!("Failed to export {table} to parquet"))?;
181
182        Ok(())
183    }
184}
185
186/// Convert a DuckDB `Value` to a display string.
187fn value_to_string(value: &Value) -> String {
188    match value {
189        Value::Null => "NULL".to_string(),
190        Value::Boolean(b) => b.to_string(),
191        Value::TinyInt(n) => n.to_string(),
192        Value::SmallInt(n) => n.to_string(),
193        Value::Int(n) => n.to_string(),
194        Value::BigInt(n) => n.to_string(),
195        Value::HugeInt(n) => n.to_string(),
196        Value::UTinyInt(n) => n.to_string(),
197        Value::USmallInt(n) => n.to_string(),
198        Value::UInt(n) => n.to_string(),
199        Value::UBigInt(n) => n.to_string(),
200        Value::Float(f) => f.to_string(),
201        Value::Double(f) => f.to_string(),
202        Value::Text(s) => s.clone(),
203        Value::Blob(b) => format!("<blob {} bytes>", b.len()),
204        Value::Date32(d) => d.to_string(),
205        Value::Time64(..) => format!("{value:?}"),
206        Value::Timestamp(..) => format!("{value:?}"),
207        Value::Interval { .. } => format!("{value:?}"),
208        _ => format!("{value:?}"),
209    }
210}