Skip to main content

dbkit_rs/
base_handler.rs

1use crate::DbkitError;
2use deadpool_postgres::Pool;
3#[cfg(feature = "duckdb")]
4use duckdb::Connection as DuckConnection;
5use std::sync::Arc;
6#[cfg(feature = "duckdb")]
7use std::sync::Mutex;
8#[cfg(feature = "duckdb")]
9use tokio::task;
10use tokio_postgres::Row as PgRow;
11use tokio_postgres::types::ToSql;
12use tracing::warn;
13
14// ---------------------------------------------------------------------------
15// Write operations (Postgres)
16// ---------------------------------------------------------------------------
17
18/// Unified write operation types for Postgres.
19pub enum WriteOp<'a> {
20    /// Single query with optional return.
21    Single {
22        query: &'a str,
23        params: &'a [&'a (dyn ToSql + Sync)],
24        mode: FetchMode,
25    },
26    /// Batch of DDL statements executed in a single transaction.
27    BatchDDL { queries: &'a [&'a str] },
28    /// Same query executed for each parameter set in a transaction.
29    BatchParams {
30        query: &'a str,
31        params_list: Vec<Vec<Box<dyn ToSql + Sync + Send>>>,
32    },
33}
34
35// ---------------------------------------------------------------------------
36// Read operations (DuckDB)
37// ---------------------------------------------------------------------------
38
/// A read to run against DuckDB through `BaseHandler::execute_read`.
///
/// `F` turns one DuckDB row into a `T`. Both must be `Send + 'static`
/// because the query is executed inside `tokio::task::spawn_blocking`.
#[cfg(feature = "duckdb")]
pub enum ReadOp<'a, T, F>
where
    T: Send + 'static,
    F: Fn(&duckdb::Row<'_>) -> Result<T, DbkitError> + Send + 'static,
{
    /// Run `query` with `params`, pass every row through `map_fn`, and shape
    /// the mapped rows according to `mode`.
    Standard {
        /// SQL text executed by DuckDB.
        query: &'a str,
        /// Positional parameters bound to the query.
        params: Vec<DuckParam>,
        /// Row mapper applied to each result row.
        map_fn: F,
        /// How many rows the caller expects back.
        mode: FetchMode,
    },
}
54
/// One positional parameter for a DuckDB read query.
///
/// Each variant is converted to the matching `duckdb::types::Value` right
/// before the statement is executed.
#[cfg(feature = "duckdb")]
#[derive(Clone, Debug)]
pub enum DuckParam {
    /// 32-bit signed integer, bound as `Value::Int`.
    Int(i32),
    /// 64-bit signed integer, bound as `Value::BigInt`.
    Int64(i64),
    /// Double-precision float, bound as `Value::Double`.
    Float(f64),
    /// Owned text, bound as `Value::Text`.
    Text(String),
    /// Boolean, bound as `Value::Boolean`.
    Bool(bool),
    /// SQL `NULL`, bound as `Value::Null`.
    Null,
}
66
67// ---------------------------------------------------------------------------
68// Query result types
69// ---------------------------------------------------------------------------
70
/// How many rows a query is expected to return, which also decides the
/// `QueryResult` variant it produces.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare modes or use them
/// as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FetchMode {
    /// Run the statement and return no rows (`QueryResult::None`).
    None,
    /// Exactly one row is expected (`QueryResult::One`).
    One,
    /// Zero or one row is expected (`QueryResult::Optional`).
    Optional,
    /// Every returned row is kept (`QueryResult::All`).
    All,
}
79
80/// Result wrapper for Postgres write queries.
81pub enum QueryResult<T> {
82    None,
83    One(T),
84    Optional(Option<T>),
85    All(Vec<T>),
86}
87
88impl<T> QueryResult<T> {
89    pub fn one(self) -> Result<T, DbkitError> {
90        match self {
91            Self::One(v) => Ok(v),
92            _ => Err(DbkitError::RowCount {
93                expected: "One".into(),
94                actual: 0,
95            }),
96        }
97    }
98
99    pub fn optional(self) -> Result<Option<T>, DbkitError> {
100        match self {
101            Self::Optional(v) => Ok(v),
102            Self::One(v) => Ok(Some(v)),
103            Self::None => Ok(None),
104            _ => Err(DbkitError::RowCount {
105                expected: "Optional".into(),
106                actual: 0,
107            }),
108        }
109    }
110
111    pub fn all(self) -> Result<Vec<T>, DbkitError> {
112        match self {
113            Self::All(v) => Ok(v),
114            _ => Err(DbkitError::RowCount {
115                expected: "All".into(),
116                actual: 0,
117            }),
118        }
119    }
120}
121
/// Outcome of a DuckDB read.
///
/// Its only variant corresponds to `ReadOp::Standard`.
#[cfg(feature = "duckdb")]
pub enum ReadResult<T> {
    /// Rows from a `ReadOp::Standard` read, shaped by its `FetchMode`.
    Standard(QueryResult<T>),
}
127
#[cfg(feature = "duckdb")]
impl<T> ReadResult<T> {
    /// Extract the inner `QueryResult` from a `Standard` read.
    ///
    /// `Standard` is the only variant, so this never returns `Err`.
    pub fn standard(self) -> Result<QueryResult<T>, DbkitError> {
        // Single-variant enum: this pattern is irrefutable.
        let ReadResult::Standard(inner) = self;
        Ok(inner)
    }
}
136
137// ---------------------------------------------------------------------------
138// BaseHandler
139// ---------------------------------------------------------------------------
140
141/// Core query executor for Postgres writes and optionally DuckDB reads.
142pub struct BaseHandler {
143    pg_pool: Arc<Pool>,
144    #[cfg(feature = "duckdb")]
145    duck_conn: Option<Arc<Mutex<DuckConnection>>>,
146}
147
148impl BaseHandler {
149    /// Create a handler with Postgres only (for writes).
150    pub fn new(pg_pool: Arc<Pool>) -> Self {
151        Self {
152            pg_pool,
153            #[cfg(feature = "duckdb")]
154            duck_conn: None,
155        }
156    }
157
158    /// Create a handler with Postgres + DuckDB attached (for reads + writes).
159    #[cfg(feature = "duckdb")]
160    pub fn with_duckdb(
161        pg_pool: Arc<Pool>,
162        pg_connection_string: &str,
163    ) -> Result<Self, DbkitError> {
164        let duck_conn = DuckConnection::open_in_memory()
165            .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
166
167        duck_conn
168            .execute_batch("INSTALL postgres; LOAD postgres;")
169            .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
170
171        duck_conn
172            .execute(
173                &format!(
174                    "ATTACH '{}' AS pg (TYPE POSTGRES)",
175                    pg_connection_string
176                ),
177                [],
178            )
179            .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
180
181        duck_conn
182            .execute("USE pg", [])
183            .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
184
185        Ok(Self {
186            pg_pool,
187            duck_conn: Some(Arc::new(Mutex::new(duck_conn))),
188        })
189    }
190
191    /// Whether DuckDB is attached for reads.
192    pub fn has_duckdb(&self) -> bool {
193        #[cfg(feature = "duckdb")]
194        {
195            self.duck_conn.is_some()
196        }
197        #[cfg(not(feature = "duckdb"))]
198        {
199            false
200        }
201    }
202
203    /// Get a reference to the Postgres pool.
204    pub fn pool(&self) -> &Arc<Pool> {
205        &self.pg_pool
206    }
207
208    // ==================== UNIFIED WRITE ====================
209
210    /// Execute a write operation against Postgres.
211    pub async fn execute_write(
212        &self,
213        op: WriteOp<'_>,
214    ) -> Result<QueryResult<PgRow>, DbkitError> {
215        let mut client = self
216            .pg_pool
217            .get()
218            .await
219            .map_err(|e| DbkitError::Pool(e.to_string()))?;
220
221        match op {
222            WriteOp::Single {
223                query,
224                params,
225                mode,
226            } => match mode {
227                FetchMode::None => {
228                    client.execute(query, params).await?;
229                    Ok(QueryResult::None)
230                }
231                FetchMode::One => {
232                    let row = client.query_one(query, params).await?;
233                    Ok(QueryResult::One(row))
234                }
235                FetchMode::Optional => {
236                    let row = client.query_opt(query, params).await?;
237                    Ok(QueryResult::Optional(row))
238                }
239                FetchMode::All => {
240                    let rows = client.query(query, params).await?;
241                    Ok(QueryResult::All(rows))
242                }
243            },
244
245            WriteOp::BatchDDL { queries } => {
246                let transaction = client.transaction().await?;
247
248                for query in queries {
249                    transaction.execute(*query, &[]).await?;
250                }
251
252                transaction.commit().await?;
253                Ok(QueryResult::None)
254            }
255
256            WriteOp::BatchParams {
257                query,
258                params_list,
259            } => {
260                if params_list.is_empty() {
261                    return Ok(QueryResult::None);
262                }
263
264                let total = params_list.len();
265                let transaction = client.transaction().await?;
266                let stmt = transaction.prepare(query).await?;
267                let mut failed = 0usize;
268
269                let max_params = params_list.first().map(|p| p.len()).unwrap_or(0);
270                let mut params_refs: Vec<&(dyn ToSql + Sync)> =
271                    Vec::with_capacity(max_params);
272
273                for (idx, params) in params_list.iter().enumerate() {
274                    params_refs.clear();
275                    params_refs
276                        .extend(params.iter().map(|p| p.as_ref() as &(dyn ToSql + Sync)));
277                    if let Err(e) = transaction.execute(&stmt, &params_refs[..]).await {
278                        warn!("BatchParams row {}/{} failed: {:?}", idx + 1, total, e);
279                        failed += 1;
280                    }
281                }
282
283                transaction.commit().await?;
284
285                if failed > 0 {
286                    warn!(
287                        "BatchParams: {}/{} succeeded, {} failed",
288                        total - failed,
289                        total,
290                        failed
291                    );
292                }
293
294                Ok(QueryResult::None)
295            }
296        }
297    }
298
299    // ==================== UNIFIED READ ====================
300
301    /// Execute a read operation against DuckDB.
302    #[cfg(feature = "duckdb")]
303    pub async fn execute_read<T, F>(
304        &self,
305        op: ReadOp<'_, T, F>,
306    ) -> Result<ReadResult<T>, DbkitError>
307    where
308        F: Fn(&duckdb::Row<'_>) -> Result<T, DbkitError> + Send + 'static,
309        T: Send + 'static,
310    {
311        let duck_conn = self
312            .duck_conn
313            .as_ref()
314            .ok_or(DbkitError::DuckDbNotInitialized)?
315            .clone();
316
317        match op {
318            ReadOp::Standard {
319                query,
320                params,
321                map_fn,
322                mode,
323            } => {
324                let query = query.to_string();
325                let params = params.clone();
326
327                let results = task::spawn_blocking(move || {
328                    let conn = duck_conn
329                        .lock()
330                        .map_err(|e| DbkitError::LockPoisoned(e.to_string()))?;
331                    let mut stmt = conn
332                        .prepare(&query)
333                        .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
334
335                    let duck_values = Self::convert_params(&params);
336                    let param_refs: Vec<&dyn duckdb::ToSql> =
337                        duck_values.iter().map(|v| v as &dyn duckdb::ToSql).collect();
338
339                    let rows = stmt
340                        .query_map(param_refs.as_slice(), |row| {
341                            map_fn(row).map_err(|e| {
342                                duckdb::Error::InvalidParameterName(e.to_string())
343                            })
344                        })
345                        .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
346
347                    let mut results = Vec::new();
348                    for row in rows {
349                        results
350                            .push(row.map_err(|e| DbkitError::DuckDb(e.to_string()))?);
351                    }
352                    Ok::<Vec<T>, DbkitError>(results)
353                })
354                .await
355                .map_err(|e| DbkitError::TaskJoin(e.to_string()))??;
356
357                let query_result = match mode {
358                    FetchMode::None => QueryResult::None,
359                    FetchMode::One => {
360                        if results.len() != 1 {
361                            return Err(DbkitError::RowCount {
362                                expected: "1".into(),
363                                actual: results.len(),
364                            });
365                        }
366                        QueryResult::One(results.into_iter().next().unwrap())
367                    }
368                    FetchMode::Optional => {
369                        if results.len() > 1 {
370                            return Err(DbkitError::RowCount {
371                                expected: "0 or 1".into(),
372                                actual: results.len(),
373                            });
374                        }
375                        QueryResult::Optional(results.into_iter().next())
376                    }
377                    FetchMode::All => QueryResult::All(results),
378                };
379
380                Ok(ReadResult::Standard(query_result))
381            }
382        }
383    }
384
385    // ==================== PARAM CONVERSION ====================
386
387    #[cfg(feature = "duckdb")]
388    fn convert_params(params: &[DuckParam]) -> Vec<duckdb::types::Value> {
389        params
390            .iter()
391            .map(|p| match p {
392                DuckParam::Int(v) => duckdb::types::Value::Int(*v),
393                DuckParam::Int64(v) => duckdb::types::Value::BigInt(*v),
394                DuckParam::Float(v) => duckdb::types::Value::Double(*v),
395                DuckParam::Text(v) => duckdb::types::Value::Text(v.clone()),
396                DuckParam::Bool(v) => duckdb::types::Value::Boolean(*v),
397                DuckParam::Null => duckdb::types::Value::Null,
398            })
399            .collect()
400    }
401}