Skip to main content

qail_pg/driver/
ops.rs

1//! PgDriver operations: transaction control, batch execution, statement timeout,
2//! RLS context, pipeline, COPY bulk/export, and cursor streaming.
3
4use super::core::PgDriver;
5use super::pipeline::AstPipelineMode;
6use super::prepared::PreparedStatement;
7use super::rls;
8use super::types::*;
9use super::{AutoCountPath, AutoCountPlan};
10use qail_core::ast::Qail;
11
12impl PgDriver {
13    // ==================== RAW SQL COMPAT ====================
14
15    /// Execute raw SQL using PostgreSQL Simple Query protocol.
16    ///
17    /// Returns number of rows returned by the statement (or `0` for statements
18    /// that do not produce rows).
19    ///
20    /// This compatibility API exists for legacy tests/examples; prefer AST APIs
21    /// (`execute`, `fetch_all`) for application code.
22    pub async fn execute_raw(&mut self, sql: &str) -> PgResult<u64> {
23        let rows = self.connection.simple_query(sql).await?;
24        Ok(rows.len() as u64)
25    }
26
27    /// Execute raw SQL using PostgreSQL Simple Query protocol and return rows.
28    ///
29    /// This compatibility API exists for legacy tests/examples; prefer AST APIs
30    /// (`execute`, `fetch_all`) for application code.
31    pub async fn fetch_raw(&mut self, sql: &str) -> PgResult<Vec<PgRow>> {
32        self.connection.simple_query(sql).await
33    }
34
35    // ==================== TRANSACTION CONTROL ====================
36
37    /// Begin a transaction (AST-native).
38    pub async fn begin(&mut self) -> PgResult<()> {
39        self.connection.begin_transaction().await
40    }
41
42    /// Commit the current transaction (AST-native).
43    pub async fn commit(&mut self) -> PgResult<()> {
44        self.connection.commit().await
45    }
46
47    /// Rollback the current transaction (AST-native).
48    pub async fn rollback(&mut self) -> PgResult<()> {
49        self.connection.rollback().await
50    }
51
52    /// Create a named savepoint within the current transaction.
53    /// Savepoints allow partial rollback within a transaction.
54    /// Use `rollback_to()` to return to this savepoint.
55    /// # Example
56    /// ```ignore
57    /// driver.begin().await?;
58    /// driver.execute(&insert1).await?;
59    /// driver.savepoint("sp1").await?;
60    /// driver.execute(&insert2).await?;
61    /// driver.rollback_to("sp1").await?; // Undo insert2, keep insert1
62    /// driver.commit().await?;
63    /// ```
64    pub async fn savepoint(&mut self, name: &str) -> PgResult<()> {
65        self.connection.savepoint(name).await
66    }
67
68    /// Rollback to a previously created savepoint.
69    /// Discards all changes since the named savepoint was created,
70    /// but keeps the transaction open.
71    pub async fn rollback_to(&mut self, name: &str) -> PgResult<()> {
72        self.connection.rollback_to(name).await
73    }
74
75    /// Release a savepoint (free resources, if no longer needed).
76    /// After release, the savepoint cannot be rolled back to.
77    pub async fn release_savepoint(&mut self, name: &str) -> PgResult<()> {
78        self.connection.release_savepoint(name).await
79    }
80
81    // ==================== BATCH TRANSACTIONS ====================
82
83    /// Execute multiple commands in a single atomic transaction.
84    /// All commands succeed or all are rolled back.
85    /// # Example
86    /// ```ignore
87    /// let cmds = vec![
88    ///     Qail::add("users").columns(["name"]).values(["Alice"]),
89    ///     Qail::add("users").columns(["name"]).values(["Bob"]),
90    /// ];
91    /// let results = driver.execute_batch(&cmds).await?;
92    /// // results = [1, 1] (rows affected)
93    /// ```
94    pub async fn execute_batch(&mut self, cmds: &[Qail]) -> PgResult<Vec<u64>> {
95        self.begin().await?;
96        let mut results = Vec::with_capacity(cmds.len());
97        for cmd in cmds {
98            match self.execute(cmd).await {
99                Ok(n) => results.push(n),
100                Err(e) => {
101                    self.rollback().await?;
102                    return Err(e);
103                }
104            }
105        }
106        self.commit().await?;
107        Ok(results)
108    }
109
110    // ==================== STATEMENT TIMEOUT ====================
111
112    /// Set statement timeout for this connection (in milliseconds).
113    /// # Example
114    /// ```ignore
115    /// driver.set_statement_timeout(30_000).await?; // 30 seconds
116    /// ```
117    pub async fn set_statement_timeout(&mut self, ms: u32) -> PgResult<()> {
118        let cmd = Qail::session_set("statement_timeout", ms.to_string());
119        self.execute(&cmd).await.map(|_| ())
120    }
121
122    /// Reset statement timeout to default (no limit).
123    pub async fn reset_statement_timeout(&mut self) -> PgResult<()> {
124        let cmd = Qail::session_reset("statement_timeout");
125        self.execute(&cmd).await.map(|_| ())
126    }
127
128    // ==================== RLS (MULTI-TENANT) ====================
129
130    /// Set the RLS context for multi-tenant data isolation.
131    ///
132    /// Configures PostgreSQL session variables (`app.current_tenant_id`, etc.)
133    /// so that RLS policies automatically filter data by tenant.
134    ///
135    /// Since `PgDriver` takes `&mut self`, the borrow checker guarantees
136    /// that `set_config` and all subsequent queries execute on the **same
137    /// connection** — no pool race conditions possible.
138    ///
139    /// # Example
140    /// ```ignore
141    /// driver.set_rls_context(RlsContext::tenant("tenant-123")).await?;
142    /// let orders = driver.fetch_all(&Qail::get("orders")).await?;
143    /// // orders only contains rows for tenant-123
144    /// ```
145    pub async fn set_rls_context(&mut self, ctx: rls::RlsContext) -> PgResult<()> {
146        let sql = rls::context_to_sql(&ctx);
147        if sql.as_bytes().contains(&0) {
148            return Err(crate::PgError::Protocol(
149                "SQL contains NULL byte (0x00) which is invalid in PostgreSQL".to_string(),
150            ));
151        }
152        self.connection.execute_simple(&sql).await?;
153        self.rls_context = Some(ctx);
154        Ok(())
155    }
156
157    /// Clear the RLS context, resetting session variables to safe defaults.
158    ///
159    /// After clearing, all RLS-protected queries will return zero rows
160    /// (empty tenant scope matches nothing).
161    pub async fn clear_rls_context(&mut self) -> PgResult<()> {
162        let sql = rls::reset_sql();
163        if sql.as_bytes().contains(&0) {
164            return Err(crate::PgError::Protocol(
165                "SQL contains NULL byte (0x00) which is invalid in PostgreSQL".to_string(),
166            ));
167        }
168        self.connection.execute_simple(sql).await?;
169        self.rls_context = None;
170        Ok(())
171    }
172
173    /// Get the current RLS context, if any.
174    pub fn rls_context(&self) -> Option<&rls::RlsContext> {
175        self.rls_context.as_ref()
176    }
177
178    // ==================== PIPELINE (BATCH) ====================
179
180    /// Execute multiple Qail ASTs in a single network round-trip (PIPELINING).
181    /// # Example
182    /// ```ignore
183    /// let cmds: Vec<Qail> = (1..=1000)
184    ///     .map(|i| Qail::get("harbors").columns(["id", "name"]).limit(i))
185    ///     .collect();
186    /// let count = driver.pipeline_execute_count(&cmds).await?;
187    /// assert_eq!(count, 1000);
188    /// ```
189    pub async fn pipeline_execute_count(&mut self, cmds: &[Qail]) -> PgResult<usize> {
190        self.pipeline_execute_count_with_mode(cmds, AstPipelineMode::Auto)
191            .await
192    }
193
194    /// Execute commands with runtime auto strategy and return both count and plan.
195    ///
196    /// Strategy:
197    /// - `len <= 1`: single cached query path
198    /// - `2..8`: one-shot pipeline
199    /// - `>= 8`: cached pipeline
200    pub async fn execute_count_auto_with_plan(
201        &mut self,
202        cmds: &[Qail],
203    ) -> PgResult<(usize, AutoCountPlan)> {
204        let plan = AutoCountPlan::for_driver(cmds.len());
205
206        let completed = match plan.path {
207            AutoCountPath::SingleCached => {
208                if cmds.is_empty() {
209                    0
210                } else {
211                    let _ = self.fetch_all_cached(&cmds[0]).await?;
212                    1
213                }
214            }
215            AutoCountPath::PipelineOneShot => {
216                self.connection
217                    .pipeline_execute_count_ast_with_mode(cmds, AstPipelineMode::OneShot)
218                    .await?
219            }
220            AutoCountPath::PipelineCached => {
221                self.connection
222                    .pipeline_execute_count_ast_with_mode(cmds, AstPipelineMode::Cached)
223                    .await?
224            }
225            AutoCountPath::PoolParallel => {
226                unreachable!("driver auto planner cannot resolve pool-parallel")
227            }
228        };
229
230        Ok((completed, plan))
231    }
232
233    /// Execute commands with runtime auto strategy.
234    #[inline]
235    pub async fn execute_count_auto(&mut self, cmds: &[Qail]) -> PgResult<usize> {
236        let (completed, _plan) = self.execute_count_auto_with_plan(cmds).await?;
237        Ok(completed)
238    }
239
240    /// Execute multiple Qail ASTs with an explicit pipeline strategy.
241    ///
242    /// Use [`AstPipelineMode::Cached`] for repeated templates in large batches,
243    /// or [`AstPipelineMode::OneShot`] for tiny one-off batches.
244    pub async fn pipeline_execute_count_with_mode(
245        &mut self,
246        cmds: &[Qail],
247        mode: AstPipelineMode,
248    ) -> PgResult<usize> {
249        self.connection
250            .pipeline_execute_count_ast_with_mode(cmds, mode)
251            .await
252    }
253
254    /// Execute multiple Qail ASTs and return full row data.
255    pub async fn pipeline_execute_rows(&mut self, cmds: &[Qail]) -> PgResult<Vec<Vec<PgRow>>> {
256        let raw_results = self.connection.pipeline_execute_rows_ast(cmds).await?;
257
258        let results: Vec<Vec<PgRow>> = raw_results
259            .into_iter()
260            .map(|rows| {
261                rows.into_iter()
262                    .map(|columns| PgRow {
263                        columns,
264                        column_info: None,
265                    })
266                    .collect()
267            })
268            .collect();
269
270        Ok(results)
271    }
272
273    /// Prepare a SQL statement for repeated execution.
274    pub async fn prepare(&mut self, sql: &str) -> PgResult<PreparedStatement> {
275        self.connection.prepare(sql).await
276    }
277
278    /// Execute a prepared statement pipeline in FAST mode (count only).
279    pub async fn pipeline_execute_prepared_count(
280        &mut self,
281        stmt: &PreparedStatement,
282        params_batch: &[Vec<Option<Vec<u8>>>],
283    ) -> PgResult<usize> {
284        self.connection
285            .pipeline_execute_prepared_count(stmt, params_batch)
286            .await
287    }
288
289    /// Bulk insert data using PostgreSQL COPY protocol (AST-native).
290    /// Uses a Qail::Add to get validated table and column names from the AST,
291    /// not user-provided strings. This is the sound, AST-native approach.
292    /// # Example
293    /// ```ignore
294    /// // Create a Qail::Add to define table and columns
295    /// let cmd = Qail::add("users")
296    ///     .columns(["id", "name", "email"]);
297    /// // Bulk insert rows
298    /// let rows: Vec<Vec<Value>> = vec![
299    ///     vec![Value::Int(1), Value::String("Alice"), Value::String("alice@ex.com")],
300    ///     vec![Value::Int(2), Value::String("Bob"), Value::String("bob@ex.com")],
301    /// ];
302    /// driver.copy_bulk(&cmd, &rows).await?;
303    /// ```
304    pub async fn copy_bulk(
305        &mut self,
306        cmd: &Qail,
307        rows: &[Vec<qail_core::ast::Value>],
308    ) -> PgResult<u64> {
309        use qail_core::ast::Action;
310
311        if cmd.action != Action::Add {
312            return Err(PgError::Query(
313                "copy_bulk requires Qail::Add action".to_string(),
314            ));
315        }
316
317        let table = &cmd.table;
318
319        let columns: Vec<String> = cmd
320            .columns
321            .iter()
322            .filter_map(|expr| {
323                use qail_core::ast::Expr;
324                match expr {
325                    Expr::Named(name) => Some(name.clone()),
326                    Expr::Aliased { name, .. } => Some(name.clone()),
327                    Expr::Star => None, // Can't COPY with *
328                    _ => None,
329                }
330            })
331            .collect();
332
333        if columns.is_empty() {
334            return Err(PgError::Query(
335                "copy_bulk requires columns in Qail".to_string(),
336            ));
337        }
338
339        // Use optimized COPY path: direct Value → bytes encoding, single syscall
340        self.connection.copy_in_fast(table, &columns, rows).await
341    }
342
343    /// **Fastest** bulk insert using pre-encoded COPY data.
344    /// Accepts raw COPY text format bytes. Use when caller has already
345    /// encoded rows to avoid any encoding overhead.
346    /// # Format
347    /// Data should be tab-separated rows with newlines (COPY text format):
348    /// `1\thello\t3.14\n2\tworld\t2.71\n`
349    /// # Example
350    /// ```ignore
351    /// let cmd = Qail::add("users").columns(["id", "name"]);
352    /// let data = b"1\tAlice\n2\tBob\n";
353    /// driver.copy_bulk_bytes(&cmd, data).await?;
354    /// ```
355    pub async fn copy_bulk_bytes(&mut self, cmd: &Qail, data: &[u8]) -> PgResult<u64> {
356        use qail_core::ast::Action;
357
358        if cmd.action != Action::Add {
359            return Err(PgError::Query(
360                "copy_bulk_bytes requires Qail::Add action".to_string(),
361            ));
362        }
363
364        let table = &cmd.table;
365        let columns: Vec<String> = cmd
366            .columns
367            .iter()
368            .filter_map(|expr| {
369                use qail_core::ast::Expr;
370                match expr {
371                    Expr::Named(name) => Some(name.clone()),
372                    Expr::Aliased { name, .. } => Some(name.clone()),
373                    _ => None,
374                }
375            })
376            .collect();
377
378        if columns.is_empty() {
379            return Err(PgError::Query(
380                "copy_bulk_bytes requires columns in Qail".to_string(),
381            ));
382        }
383
384        // Direct to raw COPY - zero encoding!
385        self.connection.copy_in_raw(table, &columns, data).await
386    }
387
388    /// Export table data using PostgreSQL COPY TO STDOUT (zero-copy streaming).
389    /// Returns rows as tab-separated bytes for direct re-import via copy_bulk_bytes.
390    /// # Example
391    /// ```ignore
392    /// let data = driver.copy_export_table("users", &["id", "name"]).await?;
393    /// shadow_driver.copy_bulk_bytes(&cmd, &data).await?;
394    /// ```
395    pub async fn copy_export_table(
396        &mut self,
397        table: &str,
398        columns: &[String],
399    ) -> PgResult<Vec<u8>> {
400        let quote_ident = |ident: &str| -> String {
401            format!("\"{}\"", ident.replace('\0', "").replace('"', "\"\""))
402        };
403        let cols: Vec<String> = columns.iter().map(|c| quote_ident(c)).collect();
404        let sql = format!(
405            "COPY {} ({}) TO STDOUT",
406            quote_ident(table),
407            cols.join(", ")
408        );
409
410        self.connection.copy_out_raw(&sql).await
411    }
412
413    /// Stream table export using COPY TO STDOUT with bounded memory usage.
414    ///
415    /// Chunks are forwarded directly from PostgreSQL to `on_chunk`.
416    pub async fn copy_export_table_stream<F, Fut>(
417        &mut self,
418        table: &str,
419        columns: &[String],
420        on_chunk: F,
421    ) -> PgResult<()>
422    where
423        F: FnMut(Vec<u8>) -> Fut,
424        Fut: std::future::Future<Output = PgResult<()>>,
425    {
426        let quote_ident = |ident: &str| -> String {
427            format!("\"{}\"", ident.replace('\0', "").replace('"', "\"\""))
428        };
429        let cols: Vec<String> = columns.iter().map(|c| quote_ident(c)).collect();
430        let sql = format!(
431            "COPY {} ({}) TO STDOUT",
432            quote_ident(table),
433            cols.join(", ")
434        );
435        self.connection.copy_out_raw_stream(&sql, on_chunk).await
436    }
437
438    /// Stream an AST-native `Qail::Export` command as raw COPY chunks.
439    pub async fn copy_export_cmd_stream<F, Fut>(&mut self, cmd: &Qail, on_chunk: F) -> PgResult<()>
440    where
441        F: FnMut(Vec<u8>) -> Fut,
442        Fut: std::future::Future<Output = PgResult<()>>,
443    {
444        self.connection.copy_export_stream_raw(cmd, on_chunk).await
445    }
446
447    /// Stream an AST-native `Qail::Export` command as parsed text rows.
448    pub async fn copy_export_cmd_stream_rows<F>(&mut self, cmd: &Qail, on_row: F) -> PgResult<()>
449    where
450        F: FnMut(Vec<String>) -> PgResult<()>,
451    {
452        self.connection.copy_export_stream_rows(cmd, on_row).await
453    }
454
455    /// Stream large result sets using PostgreSQL cursors.
456    /// This method uses DECLARE CURSOR internally to stream rows in batches,
457    /// avoiding loading the entire result set into memory.
458    /// # Example
459    /// ```ignore
460    /// let cmd = Qail::get("large_table");
461    /// let batches = driver.stream_cmd(&cmd, 100).await?;
462    /// for batch in batches {
463    ///     for row in batch {
464    ///         // process row
465    ///     }
466    /// }
467    /// ```
468    pub async fn stream_cmd(&mut self, cmd: &Qail, batch_size: usize) -> PgResult<Vec<Vec<PgRow>>> {
469        use std::sync::atomic::{AtomicU64, Ordering};
470        static CURSOR_ID: AtomicU64 = AtomicU64::new(0);
471
472        let cursor_name = format!("qail_cursor_{}", CURSOR_ID.fetch_add(1, Ordering::SeqCst));
473
474        // AST-NATIVE: Generate SQL directly from AST (no to_sql_parameterized!)
475        use crate::protocol::AstEncoder;
476        let mut sql_buf = bytes::BytesMut::with_capacity(256);
477        let mut params: Vec<Option<Vec<u8>>> = Vec::new();
478        AstEncoder::encode_select_sql(cmd, &mut sql_buf, &mut params)
479            .map_err(|e| PgError::Encode(e.to_string()))?;
480        let sql = String::from_utf8_lossy(&sql_buf).to_string();
481
482        // Must be in a transaction for cursors
483        self.connection.begin_transaction().await?;
484
485        // Declare cursor
486        // Declare cursor with bind params — Extended Query Protocol handles $1, $2 etc.
487        self.connection
488            .declare_cursor(&cursor_name, &sql, &params)
489            .await?;
490
491        // Fetch all batches
492        let mut all_batches = Vec::new();
493        while let Some(rows) = self
494            .connection
495            .fetch_cursor(&cursor_name, batch_size)
496            .await?
497        {
498            let pg_rows: Vec<PgRow> = rows
499                .into_iter()
500                .map(|cols| PgRow {
501                    columns: cols,
502                    column_info: None,
503                })
504                .collect();
505            all_batches.push(pg_rows);
506        }
507
508        self.connection.close_cursor(&cursor_name).await?;
509        self.connection.commit().await?;
510
511        Ok(all_batches)
512    }
513}