1use crate::DbkitError;
2use deadpool_postgres::Pool;
3#[cfg(feature = "duckdb")]
4use duckdb::Connection as DuckConnection;
5use std::sync::Arc;
6#[cfg(feature = "duckdb")]
7use std::sync::Mutex;
8#[cfg(feature = "duckdb")]
9use tokio::task;
10use tokio_postgres::Row as PgRow;
11use tokio_postgres::types::ToSql;
12use tracing::warn;
13use unicode_normalization::UnicodeNormalization;
14
15#[cfg(feature = "duckdb")]
16pub use duckdb::arrow::record_batch::RecordBatch;
17
18pub enum WriteOp<'a> {
24 Single {
26 query: &'a str,
27 params: &'a [&'a (dyn ToSql + Sync)],
28 mode: FetchMode,
29 },
30 BatchDDL { queries: &'a [&'a str] },
32 BatchParams {
34 query: &'a str,
35 params_list: Vec<Vec<Box<dyn ToSql + Sync + Send>>>,
36 },
37}
38
#[cfg(feature = "duckdb")]
/// A read operation executed on the DuckDB connection by `BaseHandler::execute_read`.
///
/// `T` is the type each row is mapped to; `F` is the row-mapping closure.
pub enum ReadOp<'a, T, F>
where
    T: Send + 'static,
    F: Fn(&duckdb::Row<'_>) -> Result<T, DbkitError> + Send + 'static,
{
    /// Map every result row through `map_fn`.
    ///
    /// `mode` decides how many rows are expected and how they are returned.
    Standard {
        query: &'a str,
        params: Vec<DuckParam>,
        map_fn: F,
        mode: FetchMode,
    },
    /// Return the raw result as Arrow record batches.
    Arrow {
        query: &'a str,
        params: Vec<DuckParam>,
    },
}
63
/// Placeholder mapper type used to fill `ReadOp`'s `F` parameter for Arrow reads.
///
/// Arrow reads never map individual rows.
#[cfg(feature = "duckdb")]
type NoopMapFn = fn(&duckdb::Row<'_>) -> Result<(), DbkitError>;

#[cfg(feature = "duckdb")]
impl<'a> ReadOp<'a, (), NoopMapFn> {
    /// Builds a `ReadOp::Arrow` without the caller having to name the unused
    /// `T` / `F` type parameters.
    pub fn arrow(query: &'a str, params: Vec<DuckParam>) -> Self {
        Self::Arrow { query, params }
    }
}
78
#[cfg(feature = "duckdb")]
/// A DuckDB query parameter.
///
/// Each parameter is converted to a `duckdb::types::Value` before binding.
/// The `Opt*` variants bind SQL `NULL` when they hold `None`.
///
/// `PartialEq` lets callers compare parameter lists. `Eq` is not derivable
/// because of the `f64` payloads.
#[derive(Debug, Clone, PartialEq)]
pub enum DuckParam {
    /// Bound as `Value::Int`.
    Int(i32),
    /// Bound as `Value::BigInt`.
    Int64(i64),
    /// Bound as `Value::Double`.
    Float(f64),
    /// Bound as `Value::Text`.
    Text(String),
    /// Bound as `Value::Boolean`.
    Bool(bool),
    /// Bound as `Value::Null`.
    Null,
    /// `Value::Int`, or `NULL` for `None`.
    OptInt(Option<i32>),
    /// `Value::BigInt`, or `NULL` for `None`.
    OptInt64(Option<i64>),
    /// `Value::Double`, or `NULL` for `None`.
    OptFloat(Option<f64>),
    /// `Value::Text`, or `NULL` for `None`.
    OptText(Option<String>),
    /// `Value::Boolean`, or `NULL` for `None`.
    OptBool(Option<bool>),
}
95
/// How many result rows an operation should produce and how they are returned.
///
/// Derives `PartialEq`, `Eq` and `Hash` so callers can compare modes directly,
/// e.g. `mode == FetchMode::One`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FetchMode {
    /// Run the statement without returning rows (`QueryResult::None`).
    None,
    /// Expect exactly one row (`QueryResult::One`).
    One,
    /// Expect zero or one row (`QueryResult::Optional`).
    Optional,
    /// Return every row (`QueryResult::All`).
    All,
}
108
109pub enum QueryResult<T> {
111 None,
112 One(T),
113 Optional(Option<T>),
114 All(Vec<T>),
115}
116
117impl<T> QueryResult<T> {
118 pub fn one(self) -> Result<T, DbkitError> {
119 match self {
120 Self::One(v) => Ok(v),
121 _ => Err(DbkitError::RowCount {
122 expected: "One".into(),
123 actual: 0,
124 }),
125 }
126 }
127
128 pub fn optional(self) -> Result<Option<T>, DbkitError> {
129 match self {
130 Self::Optional(v) => Ok(v),
131 Self::One(v) => Ok(Some(v)),
132 Self::None => Ok(None),
133 _ => Err(DbkitError::RowCount {
134 expected: "Optional".into(),
135 actual: 0,
136 }),
137 }
138 }
139
140 pub fn all(self) -> Result<Vec<T>, DbkitError> {
141 match self {
142 Self::All(v) => Ok(v),
143 _ => Err(DbkitError::RowCount {
144 expected: "All".into(),
145 actual: 0,
146 }),
147 }
148 }
149}
150
#[cfg(feature = "duckdb")]
/// Result of `BaseHandler::execute_read`: either row-mapped values or Arrow batches.
pub enum ReadResult<T> {
    /// Rows from a `ReadOp::Standard` read, shaped by its `FetchMode`.
    Standard(QueryResult<T>),
    /// Record batches from a `ReadOp::Arrow` read.
    Arrow(Vec<RecordBatch>),
}

#[cfg(feature = "duckdb")]
impl<T> ReadResult<T> {
    /// Returns the row-mapped result of a `Standard` read.
    ///
    /// # Errors
    ///
    /// Returns `DbkitError::RowCount` with `expected: "Standard"` for an Arrow
    /// result. `actual` is the total number of rows across its batches.
    pub fn standard(self) -> Result<QueryResult<T>, DbkitError> {
        match self {
            Self::Standard(result) => Ok(result),
            Self::Arrow(batches) => Err(DbkitError::RowCount {
                expected: "Standard".into(),
                actual: batches.iter().map(RecordBatch::num_rows).sum(),
            }),
        }
    }

    /// Returns the record batches of an `Arrow` read.
    ///
    /// # Errors
    ///
    /// Returns `DbkitError::RowCount` with `expected: "Arrow"` for a `Standard`
    /// result. `actual` is the number of rows that result holds.
    pub fn arrow(self) -> Result<Vec<RecordBatch>, DbkitError> {
        match self {
            Self::Arrow(batches) => Ok(batches),
            Self::Standard(result) => {
                let actual = match &result {
                    QueryResult::None => 0,
                    QueryResult::One(_) => 1,
                    QueryResult::Optional(row) => usize::from(row.is_some()),
                    QueryResult::All(rows) => rows.len(),
                };
                Err(DbkitError::RowCount {
                    expected: "Arrow".into(),
                    actual,
                })
            }
        }
    }
}
180
181pub struct BaseHandler {
187 pg_pool: Arc<Pool>,
188 #[cfg(feature = "duckdb")]
189 duck_conn: Option<Arc<Mutex<DuckConnection>>>,
190}
191
192impl BaseHandler {
193 pub fn new(pg_pool: Arc<Pool>) -> Self {
195 Self {
196 pg_pool,
197 #[cfg(feature = "duckdb")]
198 duck_conn: None,
199 }
200 }
201
202 #[cfg(feature = "duckdb")]
204 pub fn with_duckdb(
205 pg_pool: Arc<Pool>,
206 pg_connection_string: &str,
207 ) -> Result<Self, DbkitError> {
208 let duck_conn = DuckConnection::open_in_memory()
209 .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
210
211 duck_conn
212 .execute_batch("INSTALL postgres; LOAD postgres;")
213 .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
214
215 duck_conn
216 .execute(
217 &format!(
218 "ATTACH '{}' AS pg (TYPE POSTGRES)",
219 pg_connection_string
220 ),
221 [],
222 )
223 .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
224
225 duck_conn
226 .execute("USE pg", [])
227 .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
228
229 Ok(Self {
230 pg_pool,
231 duck_conn: Some(Arc::new(Mutex::new(duck_conn))),
232 })
233 }
234
235 pub fn has_duckdb(&self) -> bool {
237 #[cfg(feature = "duckdb")]
238 {
239 self.duck_conn.is_some()
240 }
241 #[cfg(not(feature = "duckdb"))]
242 {
243 false
244 }
245 }
246
247 pub fn pool(&self) -> &Arc<Pool> {
249 &self.pg_pool
250 }
251
252 pub fn normalize_name(name: &str) -> String {
255 name.nfd().collect::<String>().to_lowercase()
256 }
257
258 pub async fn execute_write(
262 &self,
263 op: WriteOp<'_>,
264 ) -> Result<QueryResult<PgRow>, DbkitError> {
265 let mut client = self
266 .pg_pool
267 .get()
268 .await
269 .map_err(|e| DbkitError::Pool(e.to_string()))?;
270
271 match op {
272 WriteOp::Single {
273 query,
274 params,
275 mode,
276 } => match mode {
277 FetchMode::None => {
278 client.execute(query, params).await?;
279 Ok(QueryResult::None)
280 }
281 FetchMode::One => {
282 let row = client.query_one(query, params).await?;
283 Ok(QueryResult::One(row))
284 }
285 FetchMode::Optional => {
286 let row = client.query_opt(query, params).await?;
287 Ok(QueryResult::Optional(row))
288 }
289 FetchMode::All => {
290 let rows = client.query(query, params).await?;
291 Ok(QueryResult::All(rows))
292 }
293 },
294
295 WriteOp::BatchDDL { queries } => {
296 let transaction = client.transaction().await?;
297
298 for query in queries {
299 transaction.execute(*query, &[]).await?;
300 }
301
302 transaction.commit().await?;
303 Ok(QueryResult::None)
304 }
305
306 WriteOp::BatchParams {
307 query,
308 params_list,
309 } => {
310 if params_list.is_empty() {
311 return Ok(QueryResult::None);
312 }
313
314 let total = params_list.len();
315 let transaction = client.transaction().await?;
316 let stmt = transaction.prepare(query).await?;
317 let mut failed = 0usize;
318
319 let max_params = params_list.first().map(|p| p.len()).unwrap_or(0);
320 let mut params_refs: Vec<&(dyn ToSql + Sync)> =
321 Vec::with_capacity(max_params);
322
323 for (idx, params) in params_list.iter().enumerate() {
324 params_refs.clear();
325 params_refs
326 .extend(params.iter().map(|p| p.as_ref() as &(dyn ToSql + Sync)));
327 if let Err(e) = transaction.execute(&stmt, ¶ms_refs[..]).await {
328 warn!("BatchParams row {}/{} failed: {:?}", idx + 1, total, e);
329 failed += 1;
330 }
331 }
332
333 transaction.commit().await?;
334
335 if failed > 0 {
336 warn!(
337 "BatchParams: {}/{} succeeded, {} failed",
338 total - failed,
339 total,
340 failed
341 );
342 }
343
344 Ok(QueryResult::None)
345 }
346 }
347 }
348
349 #[cfg(feature = "duckdb")]
353 pub async fn execute_read<T, F>(
354 &self,
355 op: ReadOp<'_, T, F>,
356 ) -> Result<ReadResult<T>, DbkitError>
357 where
358 F: Fn(&duckdb::Row<'_>) -> Result<T, DbkitError> + Send + 'static,
359 T: Send + 'static,
360 {
361 let duck_conn = self
362 .duck_conn
363 .as_ref()
364 .ok_or(DbkitError::DuckDbNotInitialized)?
365 .clone();
366
367 match op {
368 ReadOp::Standard {
369 query,
370 params,
371 map_fn,
372 mode,
373 } => {
374 let query = query.to_string();
375 let params = params.clone();
376
377 let results = task::spawn_blocking(move || {
378 let conn = duck_conn
379 .lock()
380 .map_err(|e| DbkitError::LockPoisoned(e.to_string()))?;
381 let mut stmt = conn
382 .prepare(&query)
383 .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
384
385 let duck_values = Self::convert_params(¶ms);
386 let param_refs: Vec<&dyn duckdb::ToSql> =
387 duck_values.iter().map(|v| v as &dyn duckdb::ToSql).collect();
388
389 let rows = stmt
390 .query_map(param_refs.as_slice(), |row| {
391 map_fn(row).map_err(|e| {
392 duckdb::Error::InvalidParameterName(e.to_string())
393 })
394 })
395 .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
396
397 let mut results = Vec::new();
398 for row in rows {
399 results
400 .push(row.map_err(|e| DbkitError::DuckDb(e.to_string()))?);
401 }
402 Ok::<Vec<T>, DbkitError>(results)
403 })
404 .await
405 .map_err(|e| DbkitError::TaskJoin(e.to_string()))??;
406
407 let query_result = match mode {
408 FetchMode::None => QueryResult::None,
409 FetchMode::One => {
410 if results.len() != 1 {
411 return Err(DbkitError::RowCount {
412 expected: "1".into(),
413 actual: results.len(),
414 });
415 }
416 QueryResult::One(results.into_iter().next().unwrap())
417 }
418 FetchMode::Optional => {
419 if results.len() > 1 {
420 return Err(DbkitError::RowCount {
421 expected: "0 or 1".into(),
422 actual: results.len(),
423 });
424 }
425 QueryResult::Optional(results.into_iter().next())
426 }
427 FetchMode::All => QueryResult::All(results),
428 };
429
430 Ok(ReadResult::Standard(query_result))
431 }
432
433 ReadOp::Arrow { query, params } => {
434 let query = query.to_string();
435 let params = params.clone();
436
437 let batches = task::spawn_blocking(move || {
438 let conn = duck_conn
439 .lock()
440 .map_err(|e| DbkitError::LockPoisoned(e.to_string()))?;
441 let mut stmt = conn
442 .prepare(&query)
443 .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
444
445 let duck_values = Self::convert_params(¶ms);
446 let param_refs: Vec<&dyn duckdb::ToSql> =
447 duck_values.iter().map(|v| v as &dyn duckdb::ToSql).collect();
448
449 let arrow_iter = stmt
450 .query_arrow(param_refs.as_slice())
451 .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
452
453 Ok::<Vec<RecordBatch>, DbkitError>(arrow_iter.collect())
454 })
455 .await
456 .map_err(|e| DbkitError::TaskJoin(e.to_string()))??;
457
458 Ok(ReadResult::Arrow(batches))
459 }
460 }
461 }
462
463 #[cfg(feature = "duckdb")]
469 pub async fn sync_tables(&self, tables: &[&str]) -> Result<(), DbkitError> {
470 let duck_conn = self
471 .duck_conn
472 .as_ref()
473 .ok_or(DbkitError::DuckDbNotInitialized)?
474 .clone();
475
476 let tables: Vec<String> = tables.iter().map(|t| t.to_string()).collect();
477
478 task::spawn_blocking(move || {
479 let conn = duck_conn
480 .lock()
481 .map_err(|e| DbkitError::LockPoisoned(e.to_string()))?;
482
483 for table in &tables {
484 let sql = format!(
485 "CREATE OR REPLACE TABLE memory.main.{table} AS SELECT * FROM pg.public.{table}"
486 );
487 conn.execute(&sql, [])
488 .map_err(|e| DbkitError::DuckDb(format!("sync {table}: {e}")))?;
489 }
490 Ok(())
491 })
492 .await
493 .map_err(|e| DbkitError::TaskJoin(e.to_string()))?
494 }
495
496 #[cfg(feature = "duckdb")]
501 pub async fn sync_table_filtered(
502 &self,
503 table: &str,
504 filter: &str,
505 params: &[DuckParam],
506 ) -> Result<(), DbkitError> {
507 let duck_conn = self
508 .duck_conn
509 .as_ref()
510 .ok_or(DbkitError::DuckDbNotInitialized)?
511 .clone();
512
513 let table = table.to_string();
514 let filter = filter.to_string();
515 let params = params.to_vec();
516
517 task::spawn_blocking(move || {
518 let conn = duck_conn
519 .lock()
520 .map_err(|e| DbkitError::LockPoisoned(e.to_string()))?;
521
522 let sql = format!(
523 "CREATE OR REPLACE TABLE memory.main.{table} AS SELECT * FROM pg.public.{table} WHERE {filter}"
524 );
525
526 let duck_values = Self::convert_params(¶ms);
527 let param_refs: Vec<&dyn duckdb::ToSql> =
528 duck_values.iter().map(|v| v as &dyn duckdb::ToSql).collect();
529
530 conn.execute(&sql, param_refs.as_slice())
531 .map_err(|e| DbkitError::DuckDb(format!("sync_filtered {table}: {e}")))?;
532
533 Ok(())
534 })
535 .await
536 .map_err(|e| DbkitError::TaskJoin(e.to_string()))?
537 }
538
539 #[cfg(feature = "duckdb")]
542 fn convert_params(params: &[DuckParam]) -> Vec<duckdb::types::Value> {
543 params
544 .iter()
545 .map(|p| match p {
546 DuckParam::Int(v) => duckdb::types::Value::Int(*v),
547 DuckParam::Int64(v) => duckdb::types::Value::BigInt(*v),
548 DuckParam::Float(v) => duckdb::types::Value::Double(*v),
549 DuckParam::Text(v) => duckdb::types::Value::Text(v.clone()),
550 DuckParam::Bool(v) => duckdb::types::Value::Boolean(*v),
551 DuckParam::Null => duckdb::types::Value::Null,
552 DuckParam::OptInt(v) => match v {
553 Some(val) => duckdb::types::Value::Int(*val),
554 None => duckdb::types::Value::Null,
555 },
556 DuckParam::OptInt64(v) => match v {
557 Some(val) => duckdb::types::Value::BigInt(*val),
558 None => duckdb::types::Value::Null,
559 },
560 DuckParam::OptFloat(v) => match v {
561 Some(val) => duckdb::types::Value::Double(*val),
562 None => duckdb::types::Value::Null,
563 },
564 DuckParam::OptText(v) => match v {
565 Some(val) => duckdb::types::Value::Text(val.clone()),
566 None => duckdb::types::Value::Null,
567 },
568 DuckParam::OptBool(v) => match v {
569 Some(val) => duckdb::types::Value::Boolean(*val),
570 None => duckdb::types::Value::Null,
571 },
572 })
573 .collect()
574 }
575}