Skip to main content

qail_pg/driver/
ops.rs

//! PgDriver operations: transaction control, batch execution, statement timeout,
//! trusted simple-query SQL, RLS context, pipelining, EXPLAIN estimates,
//! prepared statements, COPY bulk import/export, and cursor streaming.
3
4use super::core::PgDriver;
5use super::pipeline::AstPipelineMode;
6use super::prepared::PreparedStatement;
7use super::rls;
8use super::types::*;
9use super::{AutoCountPath, AutoCountPlan};
10use qail_core::ast::Qail;
11use qail_core::transpiler::ToSql;
12
13impl PgDriver {
14    // ==================== TRANSACTION CONTROL ====================
15
16    /// Begin a transaction (AST-native).
17    pub async fn begin(&mut self) -> PgResult<()> {
18        self.connection.begin_transaction().await
19    }
20
21    /// Commit the current transaction (AST-native).
22    pub async fn commit(&mut self) -> PgResult<()> {
23        self.connection.commit().await
24    }
25
26    /// Rollback the current transaction (AST-native).
27    pub async fn rollback(&mut self) -> PgResult<()> {
28        self.connection.rollback().await
29    }
30
31    /// Create a named savepoint within the current transaction.
32    /// Savepoints allow partial rollback within a transaction.
33    /// Use `rollback_to()` to return to this savepoint.
34    /// # Example
35    /// ```ignore
36    /// driver.begin().await?;
37    /// driver.execute(&insert1).await?;
38    /// driver.savepoint("sp1").await?;
39    /// driver.execute(&insert2).await?;
40    /// driver.rollback_to("sp1").await?; // Undo insert2, keep insert1
41    /// driver.commit().await?;
42    /// ```
43    pub async fn savepoint(&mut self, name: &str) -> PgResult<()> {
44        self.connection.savepoint(name).await
45    }
46
47    /// Rollback to a previously created savepoint.
48    /// Discards all changes since the named savepoint was created,
49    /// but keeps the transaction open.
50    pub async fn rollback_to(&mut self, name: &str) -> PgResult<()> {
51        self.connection.rollback_to(name).await
52    }
53
54    /// Release a savepoint (free resources, if no longer needed).
55    /// After release, the savepoint cannot be rolled back to.
56    pub async fn release_savepoint(&mut self, name: &str) -> PgResult<()> {
57        self.connection.release_savepoint(name).await
58    }
59
60    // ==================== BATCH TRANSACTIONS ====================
61
62    /// Execute multiple commands in a single atomic transaction.
63    /// All commands succeed or all are rolled back.
64    /// # Example
65    /// ```ignore
66    /// let cmds = vec![
67    ///     Qail::add("users").columns(["name"]).values(["Alice"]),
68    ///     Qail::add("users").columns(["name"]).values(["Bob"]),
69    /// ];
70    /// let results = driver.execute_batch(&cmds).await?;
71    /// // results = [1, 1] (rows affected)
72    /// ```
73    pub async fn execute_batch(&mut self, cmds: &[Qail]) -> PgResult<Vec<u64>> {
74        self.begin().await?;
75        let mut results = Vec::with_capacity(cmds.len());
76        for cmd in cmds {
77            match self.execute(cmd).await {
78                Ok(n) => results.push(n),
79                Err(e) => {
80                    self.rollback().await?;
81                    return Err(e);
82                }
83            }
84        }
85        self.commit().await?;
86        Ok(results)
87    }
88
89    // ==================== STATEMENT TIMEOUT ====================
90
91    /// Set statement timeout for this connection (in milliseconds).
92    /// # Example
93    /// ```ignore
94    /// driver.set_statement_timeout(30_000).await?; // 30 seconds
95    /// ```
96    pub async fn set_statement_timeout(&mut self, ms: u32) -> PgResult<()> {
97        let cmd = Qail::session_set("statement_timeout", ms.to_string());
98        self.execute(&cmd).await.map(|_| ())
99    }
100
101    /// Reset statement timeout to default (no limit).
102    pub async fn reset_statement_timeout(&mut self) -> PgResult<()> {
103        let cmd = Qail::session_reset("statement_timeout");
104        self.execute(&cmd).await.map(|_| ())
105    }
106
107    /// Execute trusted administrative SQL using PostgreSQL's simple-query protocol.
108    ///
109    /// This is intended for internal/bootstrap DDL that cannot yet be expressed
110    /// by the QAIL AST, such as idempotent catalog maintenance statements.
111    pub async fn execute_simple(&mut self, sql: &str) -> PgResult<()> {
112        self.connection.execute_simple(sql).await
113    }
114
115    /// Execute trusted administrative SQL and return rows.
116    ///
117    /// This is the row-returning counterpart to `execute_simple`; prefer AST
118    /// APIs for application data access.
119    pub async fn simple_query(&mut self, sql: &str) -> PgResult<Vec<PgRow>> {
120        self.connection.simple_query(sql).await
121    }
122
123    // ==================== RLS (MULTI-TENANT) ====================
124
125    /// Set the RLS context for multi-tenant data isolation.
126    ///
127    /// Configures PostgreSQL session variables (`app.current_tenant_id`, etc.)
128    /// so that RLS policies automatically filter data by tenant.
129    ///
130    /// Since `PgDriver` takes `&mut self`, the borrow checker guarantees
131    /// that `set_config` and all subsequent queries execute on the **same
132    /// connection** — no pool race conditions possible.
133    ///
134    /// # Example
135    /// ```ignore
136    /// driver.set_rls_context(RlsContext::tenant("tenant-123")).await?;
137    /// let orders = driver.fetch_all(&Qail::get("orders")).await?;
138    /// // orders only contains rows for tenant-123
139    /// ```
140    pub async fn set_rls_context(&mut self, ctx: rls::RlsContext) -> PgResult<()> {
141        let sql = rls::context_to_sql(&ctx);
142        if sql.as_bytes().contains(&0) {
143            return Err(crate::PgError::Protocol(
144                "SQL contains NULL byte (0x00) which is invalid in PostgreSQL".to_string(),
145            ));
146        }
147        self.connection.execute_simple(&sql).await?;
148        self.rls_context = Some(ctx);
149        Ok(())
150    }
151
152    /// Clear the RLS context, resetting session variables to safe defaults.
153    ///
154    /// After clearing, all RLS-protected queries will return zero rows
155    /// (empty tenant scope matches nothing).
156    pub async fn clear_rls_context(&mut self) -> PgResult<()> {
157        let sql = rls::reset_sql();
158        if sql.as_bytes().contains(&0) {
159            return Err(crate::PgError::Protocol(
160                "SQL contains NULL byte (0x00) which is invalid in PostgreSQL".to_string(),
161            ));
162        }
163        self.connection.execute_simple(sql).await?;
164        self.rls_context = None;
165        Ok(())
166    }
167
168    /// Get the current RLS context, if any.
169    pub fn rls_context(&self) -> Option<&rls::RlsContext> {
170        self.rls_context.as_ref()
171    }
172
173    // ==================== PIPELINE (BATCH) ====================
174
175    /// Execute multiple Qail ASTs in a single network round-trip (PIPELINING).
176    /// # Example
177    /// ```ignore
178    /// let cmds: Vec<Qail> = (1..=1000)
179    ///     .map(|i| Qail::get("harbors").columns(["id", "name"]).limit(i))
180    ///     .collect();
181    /// let count = driver.pipeline_execute_count(&cmds).await?;
182    /// assert_eq!(count, 1000);
183    /// ```
184    pub async fn pipeline_execute_count(&mut self, cmds: &[Qail]) -> PgResult<usize> {
185        self.pipeline_execute_count_with_mode(cmds, AstPipelineMode::Auto)
186            .await
187    }
188
189    /// Execute commands with runtime auto strategy and return both count and plan.
190    ///
191    /// Strategy:
192    /// - `len <= 1`: single cached query path
193    /// - `2..8`: one-shot pipeline
194    /// - `>= 8`: cached pipeline
195    pub async fn execute_count_auto_with_plan(
196        &mut self,
197        cmds: &[Qail],
198    ) -> PgResult<(usize, AutoCountPlan)> {
199        let plan = AutoCountPlan::for_driver(cmds.len());
200
201        let completed = match plan.path {
202            AutoCountPath::SingleCached => {
203                if cmds.is_empty() {
204                    0
205                } else {
206                    let _ = self.fetch_all_cached(&cmds[0]).await?;
207                    1
208                }
209            }
210            AutoCountPath::PipelineOneShot => {
211                self.connection
212                    .pipeline_execute_count_ast_with_mode(cmds, AstPipelineMode::OneShot)
213                    .await?
214            }
215            AutoCountPath::PipelineCached => {
216                self.connection
217                    .pipeline_execute_count_ast_with_mode(cmds, AstPipelineMode::Cached)
218                    .await?
219            }
220            AutoCountPath::PoolParallel => {
221                return Err(PgError::Protocol(
222                    "driver auto planner returned pool-parallel path".to_string(),
223                ));
224            }
225        };
226
227        Ok((completed, plan))
228    }
229
230    /// Execute commands with runtime auto strategy.
231    #[inline]
232    pub async fn execute_count_auto(&mut self, cmds: &[Qail]) -> PgResult<usize> {
233        let (completed, _plan) = self.execute_count_auto_with_plan(cmds).await?;
234        Ok(completed)
235    }
236
237    /// Execute multiple Qail ASTs with an explicit pipeline strategy.
238    ///
239    /// Use [`AstPipelineMode::Cached`] for repeated templates in large batches,
240    /// or [`AstPipelineMode::OneShot`] for tiny one-off batches.
241    pub async fn pipeline_execute_count_with_mode(
242        &mut self,
243        cmds: &[Qail],
244        mode: AstPipelineMode,
245    ) -> PgResult<usize> {
246        self.connection
247            .pipeline_execute_count_ast_with_mode(cmds, mode)
248            .await
249    }
250
251    /// Execute multiple Qail ASTs and return full row data.
252    pub async fn pipeline_execute_rows(&mut self, cmds: &[Qail]) -> PgResult<Vec<Vec<PgRow>>> {
253        let raw_results = self.connection.pipeline_execute_rows_ast(cmds).await?;
254
255        let results: Vec<Vec<PgRow>> = raw_results
256            .into_iter()
257            .map(|rows| {
258                rows.into_iter()
259                    .map(|columns| PgRow {
260                        columns,
261                        column_info: None,
262                    })
263                    .collect()
264            })
265            .collect();
266
267        Ok(results)
268    }
269
270    /// Run `EXPLAIN (FORMAT JSON)` on a Qail AST command and return parsed estimates.
271    ///
272    /// Returns `Ok(None)` when PostgreSQL returns an unexpected JSON shape.
273    pub async fn explain_estimate(
274        &mut self,
275        cmd: &Qail,
276    ) -> PgResult<Option<crate::driver::explain::ExplainEstimate>> {
277        let explain_sql = format!("EXPLAIN (FORMAT JSON) {}", cmd.to_sql());
278        let rows = self.connection.simple_query(&explain_sql).await?;
279
280        let mut json_output = String::new();
281        for row in &rows {
282            if let Some(Some(val)) = row.columns.first()
283                && let Ok(text) = std::str::from_utf8(val)
284            {
285                json_output.push_str(text);
286            }
287        }
288
289        Ok(crate::driver::explain::parse_explain_json(&json_output))
290    }
291
292    /// Prepare a SQL statement for repeated execution.
293    pub async fn prepare(&mut self, sql: &str) -> PgResult<PreparedStatement> {
294        self.connection.prepare(sql).await
295    }
296
297    /// Execute a prepared statement pipeline in FAST mode (count only).
298    pub async fn pipeline_execute_prepared_count(
299        &mut self,
300        stmt: &PreparedStatement,
301        params_batch: &[Vec<Option<Vec<u8>>>],
302    ) -> PgResult<usize> {
303        self.connection
304            .pipeline_execute_prepared_count(stmt, params_batch)
305            .await
306    }
307
308    /// Bulk insert data using PostgreSQL COPY protocol (AST-native).
309    /// Uses a Qail::Add to get validated table and column names from the AST,
310    /// not user-provided strings. This is the sound, AST-native approach.
311    /// # Example
312    /// ```ignore
313    /// // Create a Qail::Add to define table and columns
314    /// let cmd = Qail::add("users")
315    ///     .columns(["id", "name", "email"]);
316    /// // Bulk insert rows
317    /// let rows: Vec<Vec<Value>> = vec![
318    ///     vec![Value::Int(1), Value::String("Alice"), Value::String("alice@ex.com")],
319    ///     vec![Value::Int(2), Value::String("Bob"), Value::String("bob@ex.com")],
320    /// ];
321    /// driver.copy_bulk(&cmd, &rows).await?;
322    /// ```
323    pub async fn copy_bulk(
324        &mut self,
325        cmd: &Qail,
326        rows: &[Vec<qail_core::ast::Value>],
327    ) -> PgResult<u64> {
328        use qail_core::ast::Action;
329
330        if cmd.action != Action::Add {
331            return Err(PgError::Query(
332                "copy_bulk requires Qail::Add action".to_string(),
333            ));
334        }
335
336        let table = &cmd.table;
337
338        let columns: Vec<String> = cmd
339            .columns
340            .iter()
341            .filter_map(|expr| {
342                use qail_core::ast::Expr;
343                match expr {
344                    Expr::Named(name) => Some(name.clone()),
345                    Expr::Aliased { name, .. } => Some(name.clone()),
346                    Expr::Star => None, // Can't COPY with *
347                    _ => None,
348                }
349            })
350            .collect();
351
352        if columns.is_empty() {
353            return Err(PgError::Query(
354                "copy_bulk requires columns in Qail".to_string(),
355            ));
356        }
357
358        // Use optimized COPY path: direct Value → bytes encoding, single syscall
359        self.connection.copy_in_fast(table, &columns, rows).await
360    }
361
362    /// **Fastest** bulk insert using pre-encoded COPY data.
363    /// Accepts raw COPY text format bytes. Use when caller has already
364    /// encoded rows to avoid any encoding overhead.
365    /// # Format
366    /// Data should be tab-separated rows with newlines (COPY text format):
367    /// `1\thello\t3.14\n2\tworld\t2.71\n`
368    /// # Example
369    /// ```ignore
370    /// let cmd = Qail::add("users").columns(["id", "name"]);
371    /// let data = b"1\tAlice\n2\tBob\n";
372    /// driver.copy_bulk_bytes(&cmd, data).await?;
373    /// ```
374    pub async fn copy_bulk_bytes(&mut self, cmd: &Qail, data: &[u8]) -> PgResult<u64> {
375        use qail_core::ast::Action;
376
377        if cmd.action != Action::Add {
378            return Err(PgError::Query(
379                "copy_bulk_bytes requires Qail::Add action".to_string(),
380            ));
381        }
382
383        let table = &cmd.table;
384        let columns: Vec<String> = cmd
385            .columns
386            .iter()
387            .filter_map(|expr| {
388                use qail_core::ast::Expr;
389                match expr {
390                    Expr::Named(name) => Some(name.clone()),
391                    Expr::Aliased { name, .. } => Some(name.clone()),
392                    _ => None,
393                }
394            })
395            .collect();
396
397        if columns.is_empty() {
398            return Err(PgError::Query(
399                "copy_bulk_bytes requires columns in Qail".to_string(),
400            ));
401        }
402
403        // Direct to raw COPY - zero encoding!
404        self.connection.copy_in_raw(table, &columns, data).await
405    }
406
407    /// Export table data using PostgreSQL COPY TO STDOUT (zero-copy streaming).
408    /// Returns rows as tab-separated bytes for direct re-import via copy_bulk_bytes.
409    /// # Example
410    /// ```ignore
411    /// let data = driver.copy_export_table("users", &["id", "name"]).await?;
412    /// shadow_driver.copy_bulk_bytes(&cmd, &data).await?;
413    /// ```
414    pub async fn copy_export_table(
415        &mut self,
416        table: &str,
417        columns: &[String],
418    ) -> PgResult<Vec<u8>> {
419        let cols: Vec<String> = columns
420            .iter()
421            .map(|c| super::copy::quote_copy_column_ident(c))
422            .collect::<PgResult<_>>()?;
423        let sql = format!(
424            "COPY {} ({}) TO STDOUT",
425            super::copy::quote_copy_table_ref(table)?,
426            cols.join(", ")
427        );
428
429        self.connection.copy_out_raw(&sql).await
430    }
431
432    /// Stream table export using COPY TO STDOUT with bounded memory usage.
433    ///
434    /// Chunks are forwarded directly from PostgreSQL to `on_chunk`.
435    pub async fn copy_export_table_stream<F, Fut>(
436        &mut self,
437        table: &str,
438        columns: &[String],
439        on_chunk: F,
440    ) -> PgResult<()>
441    where
442        F: FnMut(Vec<u8>) -> Fut,
443        Fut: std::future::Future<Output = PgResult<()>>,
444    {
445        let cols: Vec<String> = columns
446            .iter()
447            .map(|c| super::copy::quote_copy_column_ident(c))
448            .collect::<PgResult<_>>()?;
449        let sql = format!(
450            "COPY {} ({}) TO STDOUT",
451            super::copy::quote_copy_table_ref(table)?,
452            cols.join(", ")
453        );
454        self.connection.copy_out_raw_stream(&sql, on_chunk).await
455    }
456
457    /// Stream an AST-native `Qail::Export` command as raw COPY chunks.
458    pub async fn copy_export_cmd_stream<F, Fut>(&mut self, cmd: &Qail, on_chunk: F) -> PgResult<()>
459    where
460        F: FnMut(Vec<u8>) -> Fut,
461        Fut: std::future::Future<Output = PgResult<()>>,
462    {
463        self.connection.copy_export_stream_raw(cmd, on_chunk).await
464    }
465
466    /// Stream an AST-native `Qail::Export` command as parsed text rows.
467    pub async fn copy_export_cmd_stream_rows<F>(&mut self, cmd: &Qail, on_row: F) -> PgResult<()>
468    where
469        F: FnMut(Vec<String>) -> PgResult<()>,
470    {
471        self.connection.copy_export_stream_rows(cmd, on_row).await
472    }
473
474    /// Stream large result sets using PostgreSQL cursors.
475    /// This method uses DECLARE CURSOR internally to stream rows in batches,
476    /// avoiding loading the entire result set into memory.
477    /// # Example
478    /// ```ignore
479    /// let cmd = Qail::get("large_table");
480    /// let batches = driver.stream_cmd(&cmd, 100).await?;
481    /// for batch in batches {
482    ///     for row in batch {
483    ///         // process row
484    ///     }
485    /// }
486    /// ```
487    pub async fn stream_cmd(&mut self, cmd: &Qail, batch_size: usize) -> PgResult<Vec<Vec<PgRow>>> {
488        use std::sync::atomic::{AtomicU64, Ordering};
489        static CURSOR_ID: AtomicU64 = AtomicU64::new(0);
490
491        let cursor_name = format!("qail_cursor_{}", CURSOR_ID.fetch_add(1, Ordering::SeqCst));
492
493        // AST-NATIVE: Generate SQL directly from AST (no to_sql_parameterized!)
494        use crate::protocol::AstEncoder;
495        let mut sql_buf = bytes::BytesMut::with_capacity(256);
496        let mut params: Vec<Option<Vec<u8>>> = Vec::new();
497        AstEncoder::encode_select_sql(cmd, &mut sql_buf, &mut params)
498            .map_err(|e| PgError::Encode(e.to_string()))?;
499        let sql = std::str::from_utf8(&sql_buf)
500            .map_err(|e| PgError::Encode(format!("encoded SQL is not UTF-8: {}", e)))?
501            .to_string();
502
503        // Must be in a transaction for cursors
504        self.connection.begin_transaction().await?;
505
506        let stream_result = async {
507            // Declare cursor with bind params — Extended Query Protocol handles $1, $2 etc.
508            self.connection
509                .declare_cursor(&cursor_name, &sql, &params)
510                .await?;
511
512            // Fetch all batches
513            let mut all_batches = Vec::new();
514            while let Some(rows) = self
515                .connection
516                .fetch_cursor(&cursor_name, batch_size)
517                .await?
518            {
519                let pg_rows: Vec<PgRow> = rows
520                    .into_iter()
521                    .map(|cols| PgRow {
522                        columns: cols,
523                        column_info: None,
524                    })
525                    .collect();
526                all_batches.push(pg_rows);
527            }
528
529            self.connection.close_cursor(&cursor_name).await?;
530            Ok(all_batches)
531        }
532        .await;
533
534        match stream_result {
535            Ok(all_batches) => {
536                self.connection.commit().await?;
537                Ok(all_batches)
538            }
539            Err(err) => {
540                if self.connection.rollback().await.is_err() {
541                    self.connection.mark_io_desynced();
542                }
543                Err(err)
544            }
545        }
546    }
547}