// dbkit_rs/base_handler.rs

1use crate::DbkitError;
2use deadpool_postgres::Pool;
3#[cfg(feature = "duckdb")]
4use duckdb::Connection as DuckConnection;
5use std::sync::Arc;
6#[cfg(feature = "duckdb")]
7use std::sync::Mutex;
8#[cfg(feature = "duckdb")]
9use tokio::task;
10use tokio_postgres::Row as PgRow;
11use tokio_postgres::types::ToSql;
12use tracing::warn;
13use unicode_normalization::UnicodeNormalization;
14
15#[cfg(feature = "duckdb")]
16pub use duckdb::arrow::record_batch::RecordBatch;
17
18// ---------------------------------------------------------------------------
19// Write operations (Postgres)
20// ---------------------------------------------------------------------------
21
22/// Unified write operation types for Postgres.
23pub enum WriteOp<'a> {
24    /// Single query with optional return.
25    Single {
26        query: &'a str,
27        params: &'a [&'a (dyn ToSql + Sync)],
28        mode: FetchMode,
29    },
30    /// Batch of DDL statements executed in a single transaction.
31    BatchDDL { queries: &'a [&'a str] },
32    /// Same query executed for each parameter set in a transaction.
33    BatchParams {
34        query: &'a str,
35        params_list: Vec<Vec<Box<dyn ToSql + Sync + Send>>>,
36    },
37}
38
39// ---------------------------------------------------------------------------
40// Read operations (DuckDB)
41// ---------------------------------------------------------------------------
42
/// The kinds of read that [`BaseHandler::execute_read`] can run against DuckDB.
///
/// `T` is the per-row output type and `F` the closure that builds a `T` from a
/// DuckDB row. Both must be `Send + 'static` because the query is executed on a
/// blocking worker thread.
#[cfg(feature = "duckdb")]
pub enum ReadOp<'a, T, F>
where
    T: Send + 'static,
    F: Fn(&duckdb::Row<'_>) -> Result<T, DbkitError> + Send + 'static,
{
    /// Row-by-row query: every result row is passed through `map_fn`, and
    /// `mode` decides how the mapped rows are packaged.
    Standard {
        query: &'a str,
        params: Vec<DuckParam>,
        map_fn: F,
        mode: FetchMode,
    },
    /// Columnar query whose result is returned as `Vec<RecordBatch>`.
    Arrow {
        query: &'a str,
        params: Vec<DuckParam>,
    },
}
63
/// Placeholder mapper type used to pin `ReadOp`'s generics when no row mapping
/// takes place (Arrow reads).
#[cfg(feature = "duckdb")]
type NoopMapFn = fn(&duckdb::Row<'_>) -> Result<(), DbkitError>;

#[cfg(feature = "duckdb")]
impl<'a> ReadOp<'a, (), NoopMapFn> {
    /// Builds an Arrow read so callers do not have to spell out `T` and `F`.
    ///
    /// ```ignore
    /// handler.execute_read(ReadOp::arrow("SELECT * FROM t", vec![])).await?.arrow()?
    /// ```
    pub fn arrow(query: &'a str, params: Vec<DuckParam>) -> Self {
        Self::Arrow { query, params }
    }
}
78
/// A value that can be bound to a DuckDB query parameter.
///
/// The `Opt*` variants bind the wrapped value when it is `Some` and SQL `NULL`
/// when it is `None`.
#[cfg(feature = "duckdb")]
#[derive(Debug, Clone)]
pub enum DuckParam {
    /// 32-bit signed integer.
    Int(i32),
    /// 64-bit signed integer.
    Int64(i64),
    /// 64-bit floating point number.
    Float(f64),
    /// Owned text.
    Text(String),
    /// Boolean.
    Bool(bool),
    /// SQL `NULL`.
    Null,
    /// Nullable 32-bit signed integer.
    OptInt(Option<i32>),
    /// Nullable 64-bit signed integer.
    OptInt64(Option<i64>),
    /// Nullable 64-bit floating point number.
    OptFloat(Option<f64>),
    /// Nullable text.
    OptText(Option<String>),
    /// Nullable boolean.
    OptBool(Option<bool>),
}
95
96// ---------------------------------------------------------------------------
97// Query result types
98// ---------------------------------------------------------------------------
99
/// How many rows a query is expected to produce.
///
/// The mode decides which [`QueryResult`] variant comes back.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FetchMode {
    /// Run the statement and hand back no rows.
    None,
    /// Exactly one row; any other count is an error.
    One,
    /// Zero or one row; more than one is an error.
    Optional,
    /// Every row the query produces.
    All,
}
108
109/// Result wrapper for Postgres write queries.
110pub enum QueryResult<T> {
111    None,
112    One(T),
113    Optional(Option<T>),
114    All(Vec<T>),
115}
116
117impl<T> QueryResult<T> {
118    pub fn one(self) -> Result<T, DbkitError> {
119        match self {
120            Self::One(v) => Ok(v),
121            _ => Err(DbkitError::RowCount {
122                expected: "One".into(),
123                actual: 0,
124            }),
125        }
126    }
127
128    pub fn optional(self) -> Result<Option<T>, DbkitError> {
129        match self {
130            Self::Optional(v) => Ok(v),
131            Self::One(v) => Ok(Some(v)),
132            Self::None => Ok(None),
133            _ => Err(DbkitError::RowCount {
134                expected: "Optional".into(),
135                actual: 0,
136            }),
137        }
138    }
139
140    pub fn all(self) -> Result<Vec<T>, DbkitError> {
141        match self {
142            Self::All(v) => Ok(v),
143            _ => Err(DbkitError::RowCount {
144                expected: "All".into(),
145                actual: 0,
146            }),
147        }
148    }
149}
150
/// Outcome of a DuckDB read: either mapped rows or Arrow record batches.
#[cfg(feature = "duckdb")]
pub enum ReadResult<T> {
    /// Rows produced by a `ReadOp::Standard` query.
    Standard(QueryResult<T>),
    /// Batches produced by a `ReadOp::Arrow` query.
    Arrow(Vec<RecordBatch>),
}

#[cfg(feature = "duckdb")]
impl<T> ReadResult<T> {
    /// Returns the mapped-row result, or a `RowCount` error if this is an
    /// Arrow result.
    pub fn standard(self) -> Result<QueryResult<T>, DbkitError> {
        if let Self::Standard(inner) = self {
            return Ok(inner);
        }
        Err(DbkitError::RowCount {
            expected: "Standard".into(),
            actual: 0,
        })
    }

    /// Returns the Arrow batches, or a `RowCount` error if this is a
    /// mapped-row result.
    pub fn arrow(self) -> Result<Vec<RecordBatch>, DbkitError> {
        if let Self::Arrow(record_batches) = self {
            return Ok(record_batches);
        }
        Err(DbkitError::RowCount {
            expected: "Arrow".into(),
            actual: 0,
        })
    }
}
180
181// ---------------------------------------------------------------------------
182// BaseHandler
183// ---------------------------------------------------------------------------
184
185/// Core query executor for Postgres writes and optionally DuckDB reads.
186pub struct BaseHandler {
187    pg_pool: Arc<Pool>,
188    #[cfg(feature = "duckdb")]
189    duck_conn: Option<Arc<Mutex<DuckConnection>>>,
190}
191
192impl BaseHandler {
193    /// Create a handler with Postgres only (for writes).
194    pub fn new(pg_pool: Arc<Pool>) -> Self {
195        Self {
196            pg_pool,
197            #[cfg(feature = "duckdb")]
198            duck_conn: None,
199        }
200    }
201
202    /// Create a handler with Postgres + DuckDB attached (for reads + writes).
203    #[cfg(feature = "duckdb")]
204    pub fn with_duckdb(
205        pg_pool: Arc<Pool>,
206        pg_connection_string: &str,
207    ) -> Result<Self, DbkitError> {
208        let duck_conn = DuckConnection::open_in_memory()
209            .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
210
211        duck_conn
212            .execute_batch("INSTALL postgres; LOAD postgres;")
213            .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
214
215        duck_conn
216            .execute(
217                &format!(
218                    "ATTACH '{}' AS pg (TYPE POSTGRES)",
219                    pg_connection_string
220                ),
221                [],
222            )
223            .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
224
225        duck_conn
226            .execute("USE pg", [])
227            .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
228
229        Ok(Self {
230            pg_pool,
231            duck_conn: Some(Arc::new(Mutex::new(duck_conn))),
232        })
233    }
234
235    /// Whether DuckDB is attached for reads.
236    pub fn has_duckdb(&self) -> bool {
237        #[cfg(feature = "duckdb")]
238        {
239            self.duck_conn.is_some()
240        }
241        #[cfg(not(feature = "duckdb"))]
242        {
243            false
244        }
245    }
246
247    /// Get a reference to the Postgres pool.
248    pub fn pool(&self) -> &Arc<Pool> {
249        &self.pg_pool
250    }
251
252    /// Unicode NFD normalization — decomposes characters then lowercases.
253    /// Useful for matching names with different Unicode representations.
254    pub fn normalize_name(name: &str) -> String {
255        name.nfd().collect::<String>().to_lowercase()
256    }
257
258    // ==================== UNIFIED WRITE ====================
259
260    /// Execute a write operation against Postgres.
261    pub async fn execute_write(
262        &self,
263        op: WriteOp<'_>,
264    ) -> Result<QueryResult<PgRow>, DbkitError> {
265        let mut client = self
266            .pg_pool
267            .get()
268            .await
269            .map_err(|e| DbkitError::Pool(e.to_string()))?;
270
271        match op {
272            WriteOp::Single {
273                query,
274                params,
275                mode,
276            } => match mode {
277                FetchMode::None => {
278                    client.execute(query, params).await?;
279                    Ok(QueryResult::None)
280                }
281                FetchMode::One => {
282                    let row = client.query_one(query, params).await?;
283                    Ok(QueryResult::One(row))
284                }
285                FetchMode::Optional => {
286                    let row = client.query_opt(query, params).await?;
287                    Ok(QueryResult::Optional(row))
288                }
289                FetchMode::All => {
290                    let rows = client.query(query, params).await?;
291                    Ok(QueryResult::All(rows))
292                }
293            },
294
295            WriteOp::BatchDDL { queries } => {
296                let transaction = client.transaction().await?;
297
298                for query in queries {
299                    transaction.execute(*query, &[]).await?;
300                }
301
302                transaction.commit().await?;
303                Ok(QueryResult::None)
304            }
305
306            WriteOp::BatchParams {
307                query,
308                params_list,
309            } => {
310                if params_list.is_empty() {
311                    return Ok(QueryResult::None);
312                }
313
314                let total = params_list.len();
315                let transaction = client.transaction().await?;
316                let stmt = transaction.prepare(query).await?;
317                let mut failed = 0usize;
318
319                let max_params = params_list.first().map(|p| p.len()).unwrap_or(0);
320                let mut params_refs: Vec<&(dyn ToSql + Sync)> =
321                    Vec::with_capacity(max_params);
322
323                for (idx, params) in params_list.iter().enumerate() {
324                    params_refs.clear();
325                    params_refs
326                        .extend(params.iter().map(|p| p.as_ref() as &(dyn ToSql + Sync)));
327                    if let Err(e) = transaction.execute(&stmt, &params_refs[..]).await {
328                        warn!("BatchParams row {}/{} failed: {:?}", idx + 1, total, e);
329                        failed += 1;
330                    }
331                }
332
333                transaction.commit().await?;
334
335                if failed > 0 {
336                    warn!(
337                        "BatchParams: {}/{} succeeded, {} failed",
338                        total - failed,
339                        total,
340                        failed
341                    );
342                }
343
344                Ok(QueryResult::None)
345            }
346        }
347    }
348
349    // ==================== UNIFIED READ ====================
350
351    /// Execute a read operation against DuckDB.
352    #[cfg(feature = "duckdb")]
353    pub async fn execute_read<T, F>(
354        &self,
355        op: ReadOp<'_, T, F>,
356    ) -> Result<ReadResult<T>, DbkitError>
357    where
358        F: Fn(&duckdb::Row<'_>) -> Result<T, DbkitError> + Send + 'static,
359        T: Send + 'static,
360    {
361        let duck_conn = self
362            .duck_conn
363            .as_ref()
364            .ok_or(DbkitError::DuckDbNotInitialized)?
365            .clone();
366
367        match op {
368            ReadOp::Standard {
369                query,
370                params,
371                map_fn,
372                mode,
373            } => {
374                let query = query.to_string();
375                let params = params.clone();
376
377                let results = task::spawn_blocking(move || {
378                    let conn = duck_conn
379                        .lock()
380                        .map_err(|e| DbkitError::LockPoisoned(e.to_string()))?;
381                    let mut stmt = conn
382                        .prepare(&query)
383                        .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
384
385                    let duck_values = Self::convert_params(&params);
386                    let param_refs: Vec<&dyn duckdb::ToSql> =
387                        duck_values.iter().map(|v| v as &dyn duckdb::ToSql).collect();
388
389                    let rows = stmt
390                        .query_map(param_refs.as_slice(), |row| {
391                            map_fn(row).map_err(|e| {
392                                duckdb::Error::InvalidParameterName(e.to_string())
393                            })
394                        })
395                        .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
396
397                    let mut results = Vec::new();
398                    for row in rows {
399                        results
400                            .push(row.map_err(|e| DbkitError::DuckDb(e.to_string()))?);
401                    }
402                    Ok::<Vec<T>, DbkitError>(results)
403                })
404                .await
405                .map_err(|e| DbkitError::TaskJoin(e.to_string()))??;
406
407                let query_result = match mode {
408                    FetchMode::None => QueryResult::None,
409                    FetchMode::One => {
410                        if results.len() != 1 {
411                            return Err(DbkitError::RowCount {
412                                expected: "1".into(),
413                                actual: results.len(),
414                            });
415                        }
416                        QueryResult::One(results.into_iter().next().unwrap())
417                    }
418                    FetchMode::Optional => {
419                        if results.len() > 1 {
420                            return Err(DbkitError::RowCount {
421                                expected: "0 or 1".into(),
422                                actual: results.len(),
423                            });
424                        }
425                        QueryResult::Optional(results.into_iter().next())
426                    }
427                    FetchMode::All => QueryResult::All(results),
428                };
429
430                Ok(ReadResult::Standard(query_result))
431            }
432
433            ReadOp::Arrow { query, params } => {
434                let query = query.to_string();
435                let params = params.clone();
436
437                let batches = task::spawn_blocking(move || {
438                    let conn = duck_conn
439                        .lock()
440                        .map_err(|e| DbkitError::LockPoisoned(e.to_string()))?;
441                    let mut stmt = conn
442                        .prepare(&query)
443                        .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
444
445                    let duck_values = Self::convert_params(&params);
446                    let param_refs: Vec<&dyn duckdb::ToSql> =
447                        duck_values.iter().map(|v| v as &dyn duckdb::ToSql).collect();
448
449                    let arrow_iter = stmt
450                        .query_arrow(param_refs.as_slice())
451                        .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
452
453                    Ok::<Vec<RecordBatch>, DbkitError>(arrow_iter.collect())
454                })
455                .await
456                .map_err(|e| DbkitError::TaskJoin(e.to_string()))??;
457
458                Ok(ReadResult::Arrow(batches))
459            }
460        }
461    }
462
463    // ==================== SYNC (PG -> DuckDB) ====================
464
465    /// Copy entire tables from Postgres into DuckDB local memory for analytical reads.
466    ///
467    /// Creates `memory.main.{table}` for each table, replacing any existing copy.
468    #[cfg(feature = "duckdb")]
469    pub async fn sync_tables(&self, tables: &[&str]) -> Result<(), DbkitError> {
470        let duck_conn = self
471            .duck_conn
472            .as_ref()
473            .ok_or(DbkitError::DuckDbNotInitialized)?
474            .clone();
475
476        let tables: Vec<String> = tables.iter().map(|t| t.to_string()).collect();
477
478        task::spawn_blocking(move || {
479            let conn = duck_conn
480                .lock()
481                .map_err(|e| DbkitError::LockPoisoned(e.to_string()))?;
482
483            for table in &tables {
484                let sql = format!(
485                    "CREATE OR REPLACE TABLE memory.main.{table} AS SELECT * FROM pg.public.{table}"
486                );
487                conn.execute(&sql, [])
488                    .map_err(|e| DbkitError::DuckDb(format!("sync {table}: {e}")))?;
489            }
490            Ok(())
491        })
492        .await
493        .map_err(|e| DbkitError::TaskJoin(e.to_string()))?
494    }
495
496    /// Copy a filtered subset of a Postgres table into DuckDB local memory.
497    ///
498    /// The `filter` is a SQL WHERE clause (without the `WHERE` keyword).
499    /// Creates `memory.main.{table}`, replacing any existing copy.
500    #[cfg(feature = "duckdb")]
501    pub async fn sync_table_filtered(
502        &self,
503        table: &str,
504        filter: &str,
505        params: &[DuckParam],
506    ) -> Result<(), DbkitError> {
507        let duck_conn = self
508            .duck_conn
509            .as_ref()
510            .ok_or(DbkitError::DuckDbNotInitialized)?
511            .clone();
512
513        let table = table.to_string();
514        let filter = filter.to_string();
515        let params = params.to_vec();
516
517        task::spawn_blocking(move || {
518            let conn = duck_conn
519                .lock()
520                .map_err(|e| DbkitError::LockPoisoned(e.to_string()))?;
521
522            let sql = format!(
523                "CREATE OR REPLACE TABLE memory.main.{table} AS SELECT * FROM pg.public.{table} WHERE {filter}"
524            );
525
526            let duck_values = Self::convert_params(&params);
527            let param_refs: Vec<&dyn duckdb::ToSql> =
528                duck_values.iter().map(|v| v as &dyn duckdb::ToSql).collect();
529
530            conn.execute(&sql, param_refs.as_slice())
531                .map_err(|e| DbkitError::DuckDb(format!("sync_filtered {table}: {e}")))?;
532
533            Ok(())
534        })
535        .await
536        .map_err(|e| DbkitError::TaskJoin(e.to_string()))?
537    }
538
539    // ==================== PARAM CONVERSION ====================
540
541    #[cfg(feature = "duckdb")]
542    fn convert_params(params: &[DuckParam]) -> Vec<duckdb::types::Value> {
543        params
544            .iter()
545            .map(|p| match p {
546                DuckParam::Int(v) => duckdb::types::Value::Int(*v),
547                DuckParam::Int64(v) => duckdb::types::Value::BigInt(*v),
548                DuckParam::Float(v) => duckdb::types::Value::Double(*v),
549                DuckParam::Text(v) => duckdb::types::Value::Text(v.clone()),
550                DuckParam::Bool(v) => duckdb::types::Value::Boolean(*v),
551                DuckParam::Null => duckdb::types::Value::Null,
552                DuckParam::OptInt(v) => match v {
553                    Some(val) => duckdb::types::Value::Int(*val),
554                    None => duckdb::types::Value::Null,
555                },
556                DuckParam::OptInt64(v) => match v {
557                    Some(val) => duckdb::types::Value::BigInt(*val),
558                    None => duckdb::types::Value::Null,
559                },
560                DuckParam::OptFloat(v) => match v {
561                    Some(val) => duckdb::types::Value::Double(*val),
562                    None => duckdb::types::Value::Null,
563                },
564                DuckParam::OptText(v) => match v {
565                    Some(val) => duckdb::types::Value::Text(val.clone()),
566                    None => duckdb::types::Value::Null,
567                },
568                DuckParam::OptBool(v) => match v {
569                    Some(val) => duckdb::types::Value::Boolean(*val),
570                    None => duckdb::types::Value::Null,
571                },
572            })
573            .collect()
574    }
575}