1use crate::DbkitError;
2use deadpool_postgres::Pool;
3#[cfg(feature = "duckdb")]
4use duckdb::Connection as DuckConnection;
5use std::sync::Arc;
6#[cfg(feature = "duckdb")]
7use std::sync::Mutex;
8#[cfg(feature = "duckdb")]
9use tokio::task;
10use tokio_postgres::Row as PgRow;
11use tokio_postgres::types::ToSql;
12use tracing::warn;
13
/// A write-side operation accepted by [`BaseHandler::execute_write`].
///
/// Borrowed fields share the lifetime `'a`, so an operation cannot outlive the
/// query text and parameters it points at.
pub enum WriteOp<'a> {
    /// Run a single statement with bound parameters; `mode` decides whether it is
    /// only executed or which rows are fetched back.
    Single {
        /// SQL text of the statement.
        query: &'a str,
        /// Parameters bound to the statement.
        params: &'a [&'a (dyn ToSql + Sync)],
        /// How the outcome is fetched (see [`FetchMode`]).
        mode: FetchMode,
    },
    /// Run several parameterless statements, in order, inside one transaction.
    BatchDDL { queries: &'a [&'a str] },
    /// Prepare `query` once and execute it once per entry of `params_list`.
    BatchParams {
        /// SQL text that is prepared a single time for the whole batch.
        query: &'a str,
        /// One owned parameter set per execution of `query`.
        params_list: Vec<Vec<Box<dyn ToSql + Sync + Send>>>,
    },
}
34
#[cfg(feature = "duckdb")]
/// A read-side operation accepted by [`BaseHandler::execute_read`].
///
/// `T` is the value produced for each row and `F` the closure that builds it from
/// a `duckdb::Row`. Both carry `Send + 'static` bounds, matching what
/// `execute_read` requires to move them into `task::spawn_blocking`.
pub enum ReadOp<'a, T, F>
where
    F: Fn(&duckdb::Row<'_>) -> Result<T, DbkitError> + Send + 'static,
    T: Send + 'static,
{
    /// A parameterized query whose rows are mapped one by one through `map_fn`.
    Standard {
        /// SQL text of the query.
        query: &'a str,
        /// Owned parameter values bound to the query.
        params: Vec<DuckParam>,
        /// Converts one result row into a `T`, or fails with a [`DbkitError`].
        map_fn: F,
        /// How many rows are expected (see [`FetchMode`]).
        mode: FetchMode,
    },
}
54
#[cfg(feature = "duckdb")]
/// An owned parameter value for a DuckDB query.
///
/// Each variant is translated to the matching `duckdb::types::Value` by
/// `BaseHandler::convert_params` before the query is bound.
#[derive(Debug, Clone)]
pub enum DuckParam {
    /// 32-bit signed integer, bound as `Value::Int`.
    Int(i32),
    /// 64-bit signed integer, bound as `Value::BigInt`.
    Int64(i64),
    /// 64-bit float, bound as `Value::Double`.
    Float(f64),
    /// Owned text, bound as `Value::Text`.
    Text(String),
    /// Boolean, bound as `Value::Boolean`.
    Bool(bool),
    /// SQL NULL, bound as `Value::Null`.
    Null,
}
66
/// Selects how many rows a query is expected to yield, and therefore which
/// [`QueryResult`] variant the handler returns.
///
/// The enum is a plain tag, so besides `Debug`/`Clone`/`Copy` it also implements
/// `PartialEq`, `Eq` and `Hash`; callers can compare modes or use them as keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FetchMode {
    /// Fetch nothing; the handler returns [`QueryResult::None`].
    None,
    /// Exactly one row is expected; the handler returns [`QueryResult::One`].
    One,
    /// Zero or one row is expected; the handler returns [`QueryResult::Optional`].
    Optional,
    /// Every row is returned; the handler returns [`QueryResult::All`].
    All,
}
79
/// Outcome of a query, shaped by the [`FetchMode`] it was run with.
///
/// `Debug`, `Clone`, `PartialEq` and `Eq` are derived, each available whenever
/// the row type `T` implements the same trait, so results can be logged and
/// compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryResult<T> {
    /// No rows were requested.
    None,
    /// Exactly one row.
    One(T),
    /// Zero or one row.
    Optional(Option<T>),
    /// Every row the query produced.
    All(Vec<T>),
}
87
88impl<T> QueryResult<T> {
89 pub fn one(self) -> Result<T, DbkitError> {
90 match self {
91 Self::One(v) => Ok(v),
92 _ => Err(DbkitError::RowCount {
93 expected: "One".into(),
94 actual: 0,
95 }),
96 }
97 }
98
99 pub fn optional(self) -> Result<Option<T>, DbkitError> {
100 match self {
101 Self::Optional(v) => Ok(v),
102 Self::One(v) => Ok(Some(v)),
103 Self::None => Ok(None),
104 _ => Err(DbkitError::RowCount {
105 expected: "Optional".into(),
106 actual: 0,
107 }),
108 }
109 }
110
111 pub fn all(self) -> Result<Vec<T>, DbkitError> {
112 match self {
113 Self::All(v) => Ok(v),
114 _ => Err(DbkitError::RowCount {
115 expected: "All".into(),
116 actual: 0,
117 }),
118 }
119 }
120}
121
#[cfg(feature = "duckdb")]
/// Result wrapper returned by [`BaseHandler::execute_read`], with one variant per
/// [`ReadOp`] kind.
pub enum ReadResult<T> {
    /// Outcome of a [`ReadOp::Standard`] query.
    Standard(QueryResult<T>),
}
127
#[cfg(feature = "duckdb")]
impl<T> ReadResult<T> {
    /// Unwraps the payload of [`ReadResult::Standard`].
    ///
    /// `Standard` is the only variant today, so this always returns `Ok`; the
    /// `Result` return type is part of the existing signature.
    pub fn standard(self) -> Result<QueryResult<T>, DbkitError> {
        // Irrefutable while the enum has a single variant.
        let Self::Standard(query_result) = self;
        Ok(query_result)
    }
}
136
/// Database handler that owns a shared PostgreSQL connection pool and, when the
/// `duckdb` feature is enabled, an optional DuckDB connection.
pub struct BaseHandler {
    /// Shared PostgreSQL pool; `execute_write` takes its client from here.
    pg_pool: Arc<Pool>,
    #[cfg(feature = "duckdb")]
    /// Mutex-guarded DuckDB connection used by `execute_read`; `None` when the
    /// handler was built with `new`.
    duck_conn: Option<Arc<Mutex<DuckConnection>>>,
}
147
148impl BaseHandler {
    /// Creates a handler backed only by the PostgreSQL pool.
    ///
    /// With the `duckdb` feature enabled the DuckDB connection is left unset, so
    /// `has_duckdb` returns `false` for handlers built this way.
    pub fn new(pg_pool: Arc<Pool>) -> Self {
        Self {
            pg_pool,
            #[cfg(feature = "duckdb")]
            duck_conn: None,
        }
    }
157
158 #[cfg(feature = "duckdb")]
160 pub fn with_duckdb(
161 pg_pool: Arc<Pool>,
162 pg_connection_string: &str,
163 ) -> Result<Self, DbkitError> {
164 let duck_conn = DuckConnection::open_in_memory()
165 .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
166
167 duck_conn
168 .execute_batch("INSTALL postgres; LOAD postgres;")
169 .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
170
171 duck_conn
172 .execute(
173 &format!(
174 "ATTACH '{}' AS pg (TYPE POSTGRES)",
175 pg_connection_string
176 ),
177 [],
178 )
179 .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
180
181 duck_conn
182 .execute("USE pg", [])
183 .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
184
185 Ok(Self {
186 pg_pool,
187 duck_conn: Some(Arc::new(Mutex::new(duck_conn))),
188 })
189 }
190
    /// Reports whether a DuckDB connection is attached to this handler.
    ///
    /// Always `false` when the crate is compiled without the `duckdb` feature.
    pub fn has_duckdb(&self) -> bool {
        // Exactly one of the two cfg-gated blocks survives compilation and becomes
        // the function's tail expression.
        #[cfg(feature = "duckdb")]
        {
            self.duck_conn.is_some()
        }
        #[cfg(not(feature = "duckdb"))]
        {
            false
        }
    }
202
    /// Returns a reference to the shared PostgreSQL connection pool.
    pub fn pool(&self) -> &Arc<Pool> {
        &self.pg_pool
    }
207
208 pub async fn execute_write(
212 &self,
213 op: WriteOp<'_>,
214 ) -> Result<QueryResult<PgRow>, DbkitError> {
215 let mut client = self
216 .pg_pool
217 .get()
218 .await
219 .map_err(|e| DbkitError::Pool(e.to_string()))?;
220
221 match op {
222 WriteOp::Single {
223 query,
224 params,
225 mode,
226 } => match mode {
227 FetchMode::None => {
228 client.execute(query, params).await?;
229 Ok(QueryResult::None)
230 }
231 FetchMode::One => {
232 let row = client.query_one(query, params).await?;
233 Ok(QueryResult::One(row))
234 }
235 FetchMode::Optional => {
236 let row = client.query_opt(query, params).await?;
237 Ok(QueryResult::Optional(row))
238 }
239 FetchMode::All => {
240 let rows = client.query(query, params).await?;
241 Ok(QueryResult::All(rows))
242 }
243 },
244
245 WriteOp::BatchDDL { queries } => {
246 let transaction = client.transaction().await?;
247
248 for query in queries {
249 transaction.execute(*query, &[]).await?;
250 }
251
252 transaction.commit().await?;
253 Ok(QueryResult::None)
254 }
255
256 WriteOp::BatchParams {
257 query,
258 params_list,
259 } => {
260 if params_list.is_empty() {
261 return Ok(QueryResult::None);
262 }
263
264 let total = params_list.len();
265 let transaction = client.transaction().await?;
266 let stmt = transaction.prepare(query).await?;
267 let mut failed = 0usize;
268
269 let max_params = params_list.first().map(|p| p.len()).unwrap_or(0);
270 let mut params_refs: Vec<&(dyn ToSql + Sync)> =
271 Vec::with_capacity(max_params);
272
273 for (idx, params) in params_list.iter().enumerate() {
274 params_refs.clear();
275 params_refs
276 .extend(params.iter().map(|p| p.as_ref() as &(dyn ToSql + Sync)));
277 if let Err(e) = transaction.execute(&stmt, ¶ms_refs[..]).await {
278 warn!("BatchParams row {}/{} failed: {:?}", idx + 1, total, e);
279 failed += 1;
280 }
281 }
282
283 transaction.commit().await?;
284
285 if failed > 0 {
286 warn!(
287 "BatchParams: {}/{} succeeded, {} failed",
288 total - failed,
289 total,
290 failed
291 );
292 }
293
294 Ok(QueryResult::None)
295 }
296 }
297 }
298
299 #[cfg(feature = "duckdb")]
303 pub async fn execute_read<T, F>(
304 &self,
305 op: ReadOp<'_, T, F>,
306 ) -> Result<ReadResult<T>, DbkitError>
307 where
308 F: Fn(&duckdb::Row<'_>) -> Result<T, DbkitError> + Send + 'static,
309 T: Send + 'static,
310 {
311 let duck_conn = self
312 .duck_conn
313 .as_ref()
314 .ok_or(DbkitError::DuckDbNotInitialized)?
315 .clone();
316
317 match op {
318 ReadOp::Standard {
319 query,
320 params,
321 map_fn,
322 mode,
323 } => {
324 let query = query.to_string();
325 let params = params.clone();
326
327 let results = task::spawn_blocking(move || {
328 let conn = duck_conn
329 .lock()
330 .map_err(|e| DbkitError::LockPoisoned(e.to_string()))?;
331 let mut stmt = conn
332 .prepare(&query)
333 .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
334
335 let duck_values = Self::convert_params(¶ms);
336 let param_refs: Vec<&dyn duckdb::ToSql> =
337 duck_values.iter().map(|v| v as &dyn duckdb::ToSql).collect();
338
339 let rows = stmt
340 .query_map(param_refs.as_slice(), |row| {
341 map_fn(row).map_err(|e| {
342 duckdb::Error::InvalidParameterName(e.to_string())
343 })
344 })
345 .map_err(|e| DbkitError::DuckDb(e.to_string()))?;
346
347 let mut results = Vec::new();
348 for row in rows {
349 results
350 .push(row.map_err(|e| DbkitError::DuckDb(e.to_string()))?);
351 }
352 Ok::<Vec<T>, DbkitError>(results)
353 })
354 .await
355 .map_err(|e| DbkitError::TaskJoin(e.to_string()))??;
356
357 let query_result = match mode {
358 FetchMode::None => QueryResult::None,
359 FetchMode::One => {
360 if results.len() != 1 {
361 return Err(DbkitError::RowCount {
362 expected: "1".into(),
363 actual: results.len(),
364 });
365 }
366 QueryResult::One(results.into_iter().next().unwrap())
367 }
368 FetchMode::Optional => {
369 if results.len() > 1 {
370 return Err(DbkitError::RowCount {
371 expected: "0 or 1".into(),
372 actual: results.len(),
373 });
374 }
375 QueryResult::Optional(results.into_iter().next())
376 }
377 FetchMode::All => QueryResult::All(results),
378 };
379
380 Ok(ReadResult::Standard(query_result))
381 }
382 }
383 }
384
385 #[cfg(feature = "duckdb")]
388 fn convert_params(params: &[DuckParam]) -> Vec<duckdb::types::Value> {
389 params
390 .iter()
391 .map(|p| match p {
392 DuckParam::Int(v) => duckdb::types::Value::Int(*v),
393 DuckParam::Int64(v) => duckdb::types::Value::BigInt(*v),
394 DuckParam::Float(v) => duckdb::types::Value::Double(*v),
395 DuckParam::Text(v) => duckdb::types::Value::Text(v.clone()),
396 DuckParam::Bool(v) => duckdb::types::Value::Boolean(*v),
397 DuckParam::Null => duckdb::types::Value::Null,
398 })
399 .collect()
400 }
401}