Skip to main content

qail_pg/driver/
ops.rs

1//! PgDriver operations: transaction control, batch execution, statement timeout,
2//! RLS context, pipeline, COPY bulk/export, and cursor streaming.
3
4use super::core::PgDriver;
5use super::prepared::PreparedStatement;
6use super::rls;
7use super::types::*;
8use qail_core::ast::Qail;
9
10impl PgDriver {
11    // ==================== TRANSACTION CONTROL ====================
12
13    /// Begin a transaction (AST-native).
14    pub async fn begin(&mut self) -> PgResult<()> {
15        self.connection.begin_transaction().await
16    }
17
18    /// Commit the current transaction (AST-native).
19    pub async fn commit(&mut self) -> PgResult<()> {
20        self.connection.commit().await
21    }
22
23    /// Rollback the current transaction (AST-native).
24    pub async fn rollback(&mut self) -> PgResult<()> {
25        self.connection.rollback().await
26    }
27
28    /// Create a named savepoint within the current transaction.
29    /// Savepoints allow partial rollback within a transaction.
30    /// Use `rollback_to()` to return to this savepoint.
31    /// # Example
32    /// ```ignore
33    /// driver.begin().await?;
34    /// driver.execute(&insert1).await?;
35    /// driver.savepoint("sp1").await?;
36    /// driver.execute(&insert2).await?;
37    /// driver.rollback_to("sp1").await?; // Undo insert2, keep insert1
38    /// driver.commit().await?;
39    /// ```
40    pub async fn savepoint(&mut self, name: &str) -> PgResult<()> {
41        self.connection.savepoint(name).await
42    }
43
44    /// Rollback to a previously created savepoint.
45    /// Discards all changes since the named savepoint was created,
46    /// but keeps the transaction open.
47    pub async fn rollback_to(&mut self, name: &str) -> PgResult<()> {
48        self.connection.rollback_to(name).await
49    }
50
51    /// Release a savepoint (free resources, if no longer needed).
52    /// After release, the savepoint cannot be rolled back to.
53    pub async fn release_savepoint(&mut self, name: &str) -> PgResult<()> {
54        self.connection.release_savepoint(name).await
55    }
56
57    // ==================== BATCH TRANSACTIONS ====================
58
59    /// Execute multiple commands in a single atomic transaction.
60    /// All commands succeed or all are rolled back.
61    /// # Example
62    /// ```ignore
63    /// let cmds = vec![
64    ///     Qail::add("users").columns(["name"]).values(["Alice"]),
65    ///     Qail::add("users").columns(["name"]).values(["Bob"]),
66    /// ];
67    /// let results = driver.execute_batch(&cmds).await?;
68    /// // results = [1, 1] (rows affected)
69    /// ```
70    pub async fn execute_batch(&mut self, cmds: &[Qail]) -> PgResult<Vec<u64>> {
71        self.begin().await?;
72        let mut results = Vec::with_capacity(cmds.len());
73        for cmd in cmds {
74            match self.execute(cmd).await {
75                Ok(n) => results.push(n),
76                Err(e) => {
77                    self.rollback().await?;
78                    return Err(e);
79                }
80            }
81        }
82        self.commit().await?;
83        Ok(results)
84    }
85
86    // ==================== STATEMENT TIMEOUT ====================
87
88    /// Set statement timeout for this connection (in milliseconds).
89    /// # Example
90    /// ```ignore
91    /// driver.set_statement_timeout(30_000).await?; // 30 seconds
92    /// ```
93    pub async fn set_statement_timeout(&mut self, ms: u32) -> PgResult<()> {
94        let cmd = Qail::session_set("statement_timeout", ms.to_string());
95        self.execute(&cmd).await.map(|_| ())
96    }
97
98    /// Reset statement timeout to default (no limit).
99    pub async fn reset_statement_timeout(&mut self) -> PgResult<()> {
100        let cmd = Qail::session_reset("statement_timeout");
101        self.execute(&cmd).await.map(|_| ())
102    }
103
104    // ==================== RLS (MULTI-TENANT) ====================
105
106    /// Set the RLS context for multi-tenant data isolation.
107    ///
108    /// Configures PostgreSQL session variables (`app.current_tenant_id`, etc.)
109    /// so that RLS policies automatically filter data by tenant.
110    ///
111    /// Since `PgDriver` takes `&mut self`, the borrow checker guarantees
112    /// that `set_config` and all subsequent queries execute on the **same
113    /// connection** — no pool race conditions possible.
114    ///
115    /// # Example
116    /// ```ignore
117    /// driver.set_rls_context(RlsContext::tenant("tenant-123")).await?;
118    /// let orders = driver.fetch_all(&Qail::get("orders")).await?;
119    /// // orders only contains rows for tenant-123
120    /// ```
121    pub async fn set_rls_context(&mut self, ctx: rls::RlsContext) -> PgResult<()> {
122        let sql = rls::context_to_sql(&ctx);
123        if sql.as_bytes().contains(&0) {
124            return Err(crate::PgError::Protocol(
125                "SQL contains NULL byte (0x00) which is invalid in PostgreSQL".to_string(),
126            ));
127        }
128        self.connection.execute_simple(&sql).await?;
129        self.rls_context = Some(ctx);
130        Ok(())
131    }
132
133    /// Clear the RLS context, resetting session variables to safe defaults.
134    ///
135    /// After clearing, all RLS-protected queries will return zero rows
136    /// (empty tenant scope matches nothing).
137    pub async fn clear_rls_context(&mut self) -> PgResult<()> {
138        let sql = rls::reset_sql();
139        if sql.as_bytes().contains(&0) {
140            return Err(crate::PgError::Protocol(
141                "SQL contains NULL byte (0x00) which is invalid in PostgreSQL".to_string(),
142            ));
143        }
144        self.connection.execute_simple(sql).await?;
145        self.rls_context = None;
146        Ok(())
147    }
148
149    /// Get the current RLS context, if any.
150    pub fn rls_context(&self) -> Option<&rls::RlsContext> {
151        self.rls_context.as_ref()
152    }
153
154    // ==================== PIPELINE (BATCH) ====================
155
156    /// Execute multiple Qail ASTs in a single network round-trip (PIPELINING).
157    /// # Example
158    /// ```ignore
159    /// let cmds: Vec<Qail> = (1..=1000)
160    ///     .map(|i| Qail::get("harbors").columns(["id", "name"]).limit(i))
161    ///     .collect();
162    /// let count = driver.pipeline_batch(&cmds).await?;
163    /// assert_eq!(count, 1000);
164    /// ```
165    pub async fn pipeline_batch(&mut self, cmds: &[Qail]) -> PgResult<usize> {
166        self.connection.pipeline_ast_fast(cmds).await
167    }
168
169    /// Execute multiple Qail ASTs and return full row data.
170    pub async fn pipeline_fetch(&mut self, cmds: &[Qail]) -> PgResult<Vec<Vec<PgRow>>> {
171        let raw_results = self.connection.pipeline_ast(cmds).await?;
172
173        let results: Vec<Vec<PgRow>> = raw_results
174            .into_iter()
175            .map(|rows| {
176                rows.into_iter()
177                    .map(|columns| PgRow {
178                        columns,
179                        column_info: None,
180                    })
181                    .collect()
182            })
183            .collect();
184
185        Ok(results)
186    }
187
188    /// Prepare a SQL statement for repeated execution.
189    pub async fn prepare(&mut self, sql: &str) -> PgResult<PreparedStatement> {
190        self.connection.prepare(sql).await
191    }
192
193    /// Execute a prepared statement pipeline in FAST mode (count only).
194    pub async fn pipeline_prepared_fast(
195        &mut self,
196        stmt: &PreparedStatement,
197        params_batch: &[Vec<Option<Vec<u8>>>],
198    ) -> PgResult<usize> {
199        self.connection
200            .pipeline_prepared_fast(stmt, params_batch)
201            .await
202    }
203
204    /// Bulk insert data using PostgreSQL COPY protocol (AST-native).
205    /// Uses a Qail::Add to get validated table and column names from the AST,
206    /// not user-provided strings. This is the sound, AST-native approach.
207    /// # Example
208    /// ```ignore
209    /// // Create a Qail::Add to define table and columns
210    /// let cmd = Qail::add("users")
211    ///     .columns(["id", "name", "email"]);
212    /// // Bulk insert rows
213    /// let rows: Vec<Vec<Value>> = vec![
214    ///     vec![Value::Int(1), Value::String("Alice"), Value::String("alice@ex.com")],
215    ///     vec![Value::Int(2), Value::String("Bob"), Value::String("bob@ex.com")],
216    /// ];
217    /// driver.copy_bulk(&cmd, &rows).await?;
218    /// ```
219    pub async fn copy_bulk(
220        &mut self,
221        cmd: &Qail,
222        rows: &[Vec<qail_core::ast::Value>],
223    ) -> PgResult<u64> {
224        use qail_core::ast::Action;
225
226        if cmd.action != Action::Add {
227            return Err(PgError::Query(
228                "copy_bulk requires Qail::Add action".to_string(),
229            ));
230        }
231
232        let table = &cmd.table;
233
234        let columns: Vec<String> = cmd
235            .columns
236            .iter()
237            .filter_map(|expr| {
238                use qail_core::ast::Expr;
239                match expr {
240                    Expr::Named(name) => Some(name.clone()),
241                    Expr::Aliased { name, .. } => Some(name.clone()),
242                    Expr::Star => None, // Can't COPY with *
243                    _ => None,
244                }
245            })
246            .collect();
247
248        if columns.is_empty() {
249            return Err(PgError::Query(
250                "copy_bulk requires columns in Qail".to_string(),
251            ));
252        }
253
254        // Use optimized COPY path: direct Value → bytes encoding, single syscall
255        self.connection.copy_in_fast(table, &columns, rows).await
256    }
257
258    /// **Fastest** bulk insert using pre-encoded COPY data.
259    /// Accepts raw COPY text format bytes. Use when caller has already
260    /// encoded rows to avoid any encoding overhead.
261    /// # Format
262    /// Data should be tab-separated rows with newlines (COPY text format):
263    /// `1\thello\t3.14\n2\tworld\t2.71\n`
264    /// # Example
265    /// ```ignore
266    /// let cmd = Qail::add("users").columns(["id", "name"]);
267    /// let data = b"1\tAlice\n2\tBob\n";
268    /// driver.copy_bulk_bytes(&cmd, data).await?;
269    /// ```
270    pub async fn copy_bulk_bytes(&mut self, cmd: &Qail, data: &[u8]) -> PgResult<u64> {
271        use qail_core::ast::Action;
272
273        if cmd.action != Action::Add {
274            return Err(PgError::Query(
275                "copy_bulk_bytes requires Qail::Add action".to_string(),
276            ));
277        }
278
279        let table = &cmd.table;
280        let columns: Vec<String> = cmd
281            .columns
282            .iter()
283            .filter_map(|expr| {
284                use qail_core::ast::Expr;
285                match expr {
286                    Expr::Named(name) => Some(name.clone()),
287                    Expr::Aliased { name, .. } => Some(name.clone()),
288                    _ => None,
289                }
290            })
291            .collect();
292
293        if columns.is_empty() {
294            return Err(PgError::Query(
295                "copy_bulk_bytes requires columns in Qail".to_string(),
296            ));
297        }
298
299        // Direct to raw COPY - zero encoding!
300        self.connection.copy_in_raw(table, &columns, data).await
301    }
302
303    /// Export table data using PostgreSQL COPY TO STDOUT (zero-copy streaming).
304    /// Returns rows as tab-separated bytes for direct re-import via copy_bulk_bytes.
305    /// # Example
306    /// ```ignore
307    /// let data = driver.copy_export_table("users", &["id", "name"]).await?;
308    /// shadow_driver.copy_bulk_bytes(&cmd, &data).await?;
309    /// ```
310    pub async fn copy_export_table(
311        &mut self,
312        table: &str,
313        columns: &[String],
314    ) -> PgResult<Vec<u8>> {
315        let quote_ident = |ident: &str| -> String {
316            format!("\"{}\"", ident.replace('\0', "").replace('"', "\"\""))
317        };
318        let cols: Vec<String> = columns.iter().map(|c| quote_ident(c)).collect();
319        let sql = format!(
320            "COPY {} ({}) TO STDOUT",
321            quote_ident(table),
322            cols.join(", ")
323        );
324
325        self.connection.copy_out_raw(&sql).await
326    }
327
328    /// Stream table export using COPY TO STDOUT with bounded memory usage.
329    ///
330    /// Chunks are forwarded directly from PostgreSQL to `on_chunk`.
331    pub async fn copy_export_table_stream<F, Fut>(
332        &mut self,
333        table: &str,
334        columns: &[String],
335        on_chunk: F,
336    ) -> PgResult<()>
337    where
338        F: FnMut(Vec<u8>) -> Fut,
339        Fut: std::future::Future<Output = PgResult<()>>,
340    {
341        let quote_ident = |ident: &str| -> String {
342            format!("\"{}\"", ident.replace('\0', "").replace('"', "\"\""))
343        };
344        let cols: Vec<String> = columns.iter().map(|c| quote_ident(c)).collect();
345        let sql = format!(
346            "COPY {} ({}) TO STDOUT",
347            quote_ident(table),
348            cols.join(", ")
349        );
350        self.connection.copy_out_raw_stream(&sql, on_chunk).await
351    }
352
353    /// Stream an AST-native `Qail::Export` command as raw COPY chunks.
354    pub async fn copy_export_cmd_stream<F, Fut>(&mut self, cmd: &Qail, on_chunk: F) -> PgResult<()>
355    where
356        F: FnMut(Vec<u8>) -> Fut,
357        Fut: std::future::Future<Output = PgResult<()>>,
358    {
359        self.connection.copy_export_stream_raw(cmd, on_chunk).await
360    }
361
362    /// Stream an AST-native `Qail::Export` command as parsed text rows.
363    pub async fn copy_export_cmd_stream_rows<F>(&mut self, cmd: &Qail, on_row: F) -> PgResult<()>
364    where
365        F: FnMut(Vec<String>) -> PgResult<()>,
366    {
367        self.connection.copy_export_stream_rows(cmd, on_row).await
368    }
369
370    /// Stream large result sets using PostgreSQL cursors.
371    /// This method uses DECLARE CURSOR internally to stream rows in batches,
372    /// avoiding loading the entire result set into memory.
373    /// # Example
374    /// ```ignore
375    /// let cmd = Qail::get("large_table");
376    /// let batches = driver.stream_cmd(&cmd, 100).await?;
377    /// for batch in batches {
378    ///     for row in batch {
379    ///         // process row
380    ///     }
381    /// }
382    /// ```
383    pub async fn stream_cmd(&mut self, cmd: &Qail, batch_size: usize) -> PgResult<Vec<Vec<PgRow>>> {
384        use std::sync::atomic::{AtomicU64, Ordering};
385        static CURSOR_ID: AtomicU64 = AtomicU64::new(0);
386
387        let cursor_name = format!("qail_cursor_{}", CURSOR_ID.fetch_add(1, Ordering::SeqCst));
388
389        // AST-NATIVE: Generate SQL directly from AST (no to_sql_parameterized!)
390        use crate::protocol::AstEncoder;
391        let mut sql_buf = bytes::BytesMut::with_capacity(256);
392        let mut params: Vec<Option<Vec<u8>>> = Vec::new();
393        AstEncoder::encode_select_sql(cmd, &mut sql_buf, &mut params)
394            .map_err(|e| PgError::Encode(e.to_string()))?;
395        let sql = String::from_utf8_lossy(&sql_buf).to_string();
396
397        // Must be in a transaction for cursors
398        self.connection.begin_transaction().await?;
399
400        // Declare cursor
401        // Declare cursor with bind params — Extended Query Protocol handles $1, $2 etc.
402        self.connection
403            .declare_cursor(&cursor_name, &sql, &params)
404            .await?;
405
406        // Fetch all batches
407        let mut all_batches = Vec::new();
408        while let Some(rows) = self
409            .connection
410            .fetch_cursor(&cursor_name, batch_size)
411            .await?
412        {
413            let pg_rows: Vec<PgRow> = rows
414                .into_iter()
415                .map(|cols| PgRow {
416                    columns: cols,
417                    column_info: None,
418                })
419                .collect();
420            all_batches.push(pg_rows);
421        }
422
423        self.connection.close_cursor(&cursor_name).await?;
424        self.connection.commit().await?;
425
426        Ok(all_batches)
427    }
428}