Skip to main content

qail_pg/driver/
ops.rs

1//! PgDriver operations: transaction control, batch execution, statement timeout,
2//! RLS context, pipeline, legacy/raw SQL, COPY bulk/export, and cursor streaming.
3
4use super::core::PgDriver;
5use super::prepared::PreparedStatement;
6use super::rls;
7use super::types::*;
8use qail_core::ast::Qail;
9
10impl PgDriver {
11    // ==================== TRANSACTION CONTROL ====================
12
13    /// Begin a transaction (AST-native).
14    pub async fn begin(&mut self) -> PgResult<()> {
15        self.connection.begin_transaction().await
16    }
17
18    /// Commit the current transaction (AST-native).
19    pub async fn commit(&mut self) -> PgResult<()> {
20        self.connection.commit().await
21    }
22
23    /// Rollback the current transaction (AST-native).
24    pub async fn rollback(&mut self) -> PgResult<()> {
25        self.connection.rollback().await
26    }
27
28    /// Create a named savepoint within the current transaction.
29    /// Savepoints allow partial rollback within a transaction.
30    /// Use `rollback_to()` to return to this savepoint.
31    /// # Example
32    /// ```ignore
33    /// driver.begin().await?;
34    /// driver.execute(&insert1).await?;
35    /// driver.savepoint("sp1").await?;
36    /// driver.execute(&insert2).await?;
37    /// driver.rollback_to("sp1").await?; // Undo insert2, keep insert1
38    /// driver.commit().await?;
39    /// ```
40    pub async fn savepoint(&mut self, name: &str) -> PgResult<()> {
41        self.connection.savepoint(name).await
42    }
43
44    /// Rollback to a previously created savepoint.
45    /// Discards all changes since the named savepoint was created,
46    /// but keeps the transaction open.
47    pub async fn rollback_to(&mut self, name: &str) -> PgResult<()> {
48        self.connection.rollback_to(name).await
49    }
50
51    /// Release a savepoint (free resources, if no longer needed).
52    /// After release, the savepoint cannot be rolled back to.
53    pub async fn release_savepoint(&mut self, name: &str) -> PgResult<()> {
54        self.connection.release_savepoint(name).await
55    }
56
57    // ==================== BATCH TRANSACTIONS ====================
58
59    /// Execute multiple commands in a single atomic transaction.
60    /// All commands succeed or all are rolled back.
61    /// # Example
62    /// ```ignore
63    /// let cmds = vec![
64    ///     Qail::add("users").columns(["name"]).values(["Alice"]),
65    ///     Qail::add("users").columns(["name"]).values(["Bob"]),
66    /// ];
67    /// let results = driver.execute_batch(&cmds).await?;
68    /// // results = [1, 1] (rows affected)
69    /// ```
70    pub async fn execute_batch(&mut self, cmds: &[Qail]) -> PgResult<Vec<u64>> {
71        self.begin().await?;
72        let mut results = Vec::with_capacity(cmds.len());
73        for cmd in cmds {
74            match self.execute(cmd).await {
75                Ok(n) => results.push(n),
76                Err(e) => {
77                    self.rollback().await?;
78                    return Err(e);
79                }
80            }
81        }
82        self.commit().await?;
83        Ok(results)
84    }
85
86    // ==================== STATEMENT TIMEOUT ====================
87
88    /// Set statement timeout for this connection (in milliseconds).
89    /// # Example
90    /// ```ignore
91    /// driver.set_statement_timeout(30_000).await?; // 30 seconds
92    /// ```
93    pub async fn set_statement_timeout(&mut self, ms: u32) -> PgResult<()> {
94        self.execute_raw(&format!("SET statement_timeout = {}", ms))
95            .await
96    }
97
98    /// Reset statement timeout to default (no limit).
99    pub async fn reset_statement_timeout(&mut self) -> PgResult<()> {
100        self.execute_raw("RESET statement_timeout").await
101    }
102
103    // ==================== RLS (MULTI-TENANT) ====================
104
105    /// Set the RLS context for multi-tenant data isolation.
106    ///
107    /// Configures PostgreSQL session variables (`app.current_operator_id`, etc.)
108    /// so that RLS policies automatically filter data by tenant.
109    ///
110    /// Since `PgDriver` takes `&mut self`, the borrow checker guarantees
111    /// that `set_config` and all subsequent queries execute on the **same
112    /// connection** — no pool race conditions possible.
113    ///
114    /// # Example
115    /// ```ignore
116    /// driver.set_rls_context(RlsContext::operator("op-123")).await?;
117    /// let orders = driver.fetch_all(&Qail::get("orders")).await?;
118    /// // orders only contains rows where operator_id = 'op-123'
119    /// ```
120    pub async fn set_rls_context(&mut self, ctx: rls::RlsContext) -> PgResult<()> {
121        let sql = rls::context_to_sql(&ctx);
122        self.execute_raw(&sql).await?;
123        self.rls_context = Some(ctx);
124        Ok(())
125    }
126
127    /// Clear the RLS context, resetting session variables to safe defaults.
128    ///
129    /// After clearing, all RLS-protected queries will return zero rows
130    /// (empty operator_id matches nothing).
131    pub async fn clear_rls_context(&mut self) -> PgResult<()> {
132        self.execute_raw(rls::reset_sql()).await?;
133        self.rls_context = None;
134        Ok(())
135    }
136
137    /// Get the current RLS context, if any.
138    pub fn rls_context(&self) -> Option<&rls::RlsContext> {
139        self.rls_context.as_ref()
140    }
141
142    // ==================== PIPELINE (BATCH) ====================
143
144    /// Execute multiple Qail ASTs in a single network round-trip (PIPELINING).
145    /// # Example
146    /// ```ignore
147    /// let cmds: Vec<Qail> = (1..=1000)
148    ///     .map(|i| Qail::get("harbors").columns(["id", "name"]).limit(i))
149    ///     .collect();
150    /// let count = driver.pipeline_batch(&cmds).await?;
151    /// assert_eq!(count, 1000);
152    /// ```
153    pub async fn pipeline_batch(&mut self, cmds: &[Qail]) -> PgResult<usize> {
154        self.connection.pipeline_ast_fast(cmds).await
155    }
156
157    /// Execute multiple Qail ASTs and return full row data.
158    pub async fn pipeline_fetch(&mut self, cmds: &[Qail]) -> PgResult<Vec<Vec<PgRow>>> {
159        let raw_results = self.connection.pipeline_ast(cmds).await?;
160
161        let results: Vec<Vec<PgRow>> = raw_results
162            .into_iter()
163            .map(|rows| {
164                rows.into_iter()
165                    .map(|columns| PgRow {
166                        columns,
167                        column_info: None,
168                    })
169                    .collect()
170            })
171            .collect();
172
173        Ok(results)
174    }
175
176    /// Prepare a SQL statement for repeated execution.
177    pub async fn prepare(&mut self, sql: &str) -> PgResult<PreparedStatement> {
178        self.connection.prepare(sql).await
179    }
180
181    /// Execute a prepared statement pipeline in FAST mode (count only).
182    pub async fn pipeline_prepared_fast(
183        &mut self,
184        stmt: &PreparedStatement,
185        params_batch: &[Vec<Option<Vec<u8>>>],
186    ) -> PgResult<usize> {
187        self.connection
188            .pipeline_prepared_fast(stmt, params_batch)
189            .await
190    }
191
192    // ==================== LEGACY/BOOTSTRAP ====================
193
194    /// Execute a raw SQL string.
195    /// ⚠️ **Discouraged**: Violates AST-native philosophy.
196    /// Use for bootstrap DDL only (e.g., migration table creation).
197    /// For transactions, use `begin()`, `commit()`, `rollback()`.
198    pub async fn execute_raw(&mut self, sql: &str) -> PgResult<()> {
199        // Reject literal NULL bytes - they corrupt PostgreSQL connection state
200        if sql.as_bytes().contains(&0) {
201            return Err(crate::PgError::Protocol(
202                "SQL contains NULL byte (0x00) which is invalid in PostgreSQL".to_string(),
203            ));
204        }
205        self.connection.execute_simple(sql).await
206    }
207
208    /// Execute a raw SQL query and return rows.
209    /// ⚠️ **Discouraged**: Violates AST-native philosophy.
210    /// Use for bootstrap/admin queries only.
211    pub async fn fetch_raw(&mut self, sql: &str) -> PgResult<Vec<PgRow>> {
212        if sql.as_bytes().contains(&0) {
213            return Err(crate::PgError::Protocol(
214                "SQL contains NULL byte (0x00) which is invalid in PostgreSQL".to_string(),
215            ));
216        }
217
218        use crate::protocol::PgEncoder;
219        use tokio::io::AsyncWriteExt;
220
221        // Use simple query protocol (no prepared statements)
222        let msg = PgEncoder::try_encode_query_string(sql)?;
223        self.connection.stream.write_all(&msg).await?;
224
225        let mut rows: Vec<PgRow> = Vec::new();
226        let mut column_info: Option<std::sync::Arc<ColumnInfo>> = None;
227
228        let mut error: Option<PgError> = None;
229
230        loop {
231            let msg = self.connection.recv().await?;
232            match msg {
233                crate::protocol::BackendMessage::RowDescription(fields) => {
234                    column_info = Some(std::sync::Arc::new(ColumnInfo::from_fields(&fields)));
235                }
236                crate::protocol::BackendMessage::DataRow(data) => {
237                    if error.is_none() {
238                        rows.push(PgRow {
239                            columns: data,
240                            column_info: column_info.clone(),
241                        });
242                    }
243                }
244                crate::protocol::BackendMessage::CommandComplete(_) => {}
245                crate::protocol::BackendMessage::ReadyForQuery(_) => {
246                    if let Some(err) = error {
247                        return Err(err);
248                    }
249                    return Ok(rows);
250                }
251                crate::protocol::BackendMessage::ErrorResponse(err) => {
252                    if error.is_none() {
253                        error = Some(PgError::QueryServer(err.into()));
254                    }
255                }
256                msg if is_ignorable_session_message(&msg) => {}
257                other => return Err(unexpected_backend_message("driver fetch_raw", &other)),
258            }
259        }
260    }
261
262    /// Bulk insert data using PostgreSQL COPY protocol (AST-native).
263    /// Uses a Qail::Add to get validated table and column names from the AST,
264    /// not user-provided strings. This is the sound, AST-native approach.
265    /// # Example
266    /// ```ignore
267    /// // Create a Qail::Add to define table and columns
268    /// let cmd = Qail::add("users")
269    ///     .columns(["id", "name", "email"]);
270    /// // Bulk insert rows
271    /// let rows: Vec<Vec<Value>> = vec![
272    ///     vec![Value::Int(1), Value::String("Alice"), Value::String("alice@ex.com")],
273    ///     vec![Value::Int(2), Value::String("Bob"), Value::String("bob@ex.com")],
274    /// ];
275    /// driver.copy_bulk(&cmd, &rows).await?;
276    /// ```
277    pub async fn copy_bulk(
278        &mut self,
279        cmd: &Qail,
280        rows: &[Vec<qail_core::ast::Value>],
281    ) -> PgResult<u64> {
282        use qail_core::ast::Action;
283
284        if cmd.action != Action::Add {
285            return Err(PgError::Query(
286                "copy_bulk requires Qail::Add action".to_string(),
287            ));
288        }
289
290        let table = &cmd.table;
291
292        let columns: Vec<String> = cmd
293            .columns
294            .iter()
295            .filter_map(|expr| {
296                use qail_core::ast::Expr;
297                match expr {
298                    Expr::Named(name) => Some(name.clone()),
299                    Expr::Aliased { name, .. } => Some(name.clone()),
300                    Expr::Star => None, // Can't COPY with *
301                    _ => None,
302                }
303            })
304            .collect();
305
306        if columns.is_empty() {
307            return Err(PgError::Query(
308                "copy_bulk requires columns in Qail".to_string(),
309            ));
310        }
311
312        // Use optimized COPY path: direct Value → bytes encoding, single syscall
313        self.connection.copy_in_fast(table, &columns, rows).await
314    }
315
316    /// **Fastest** bulk insert using pre-encoded COPY data.
317    /// Accepts raw COPY text format bytes. Use when caller has already
318    /// encoded rows to avoid any encoding overhead.
319    /// # Format
320    /// Data should be tab-separated rows with newlines (COPY text format):
321    /// `1\thello\t3.14\n2\tworld\t2.71\n`
322    /// # Example
323    /// ```ignore
324    /// let cmd = Qail::add("users").columns(["id", "name"]);
325    /// let data = b"1\tAlice\n2\tBob\n";
326    /// driver.copy_bulk_bytes(&cmd, data).await?;
327    /// ```
328    pub async fn copy_bulk_bytes(&mut self, cmd: &Qail, data: &[u8]) -> PgResult<u64> {
329        use qail_core::ast::Action;
330
331        if cmd.action != Action::Add {
332            return Err(PgError::Query(
333                "copy_bulk_bytes requires Qail::Add action".to_string(),
334            ));
335        }
336
337        let table = &cmd.table;
338        let columns: Vec<String> = cmd
339            .columns
340            .iter()
341            .filter_map(|expr| {
342                use qail_core::ast::Expr;
343                match expr {
344                    Expr::Named(name) => Some(name.clone()),
345                    Expr::Aliased { name, .. } => Some(name.clone()),
346                    _ => None,
347                }
348            })
349            .collect();
350
351        if columns.is_empty() {
352            return Err(PgError::Query(
353                "copy_bulk_bytes requires columns in Qail".to_string(),
354            ));
355        }
356
357        // Direct to raw COPY - zero encoding!
358        self.connection.copy_in_raw(table, &columns, data).await
359    }
360
361    /// Export table data using PostgreSQL COPY TO STDOUT (zero-copy streaming).
362    /// Returns rows as tab-separated bytes for direct re-import via copy_bulk_bytes.
363    /// # Example
364    /// ```ignore
365    /// let data = driver.copy_export_table("users", &["id", "name"]).await?;
366    /// shadow_driver.copy_bulk_bytes(&cmd, &data).await?;
367    /// ```
368    pub async fn copy_export_table(
369        &mut self,
370        table: &str,
371        columns: &[String],
372    ) -> PgResult<Vec<u8>> {
373        let quote_ident = |ident: &str| -> String {
374            format!("\"{}\"", ident.replace('\0', "").replace('"', "\"\""))
375        };
376        let cols: Vec<String> = columns.iter().map(|c| quote_ident(c)).collect();
377        let sql = format!(
378            "COPY {} ({}) TO STDOUT",
379            quote_ident(table),
380            cols.join(", ")
381        );
382
383        self.connection.copy_out_raw(&sql).await
384    }
385
386    /// Stream table export using COPY TO STDOUT with bounded memory usage.
387    ///
388    /// Chunks are forwarded directly from PostgreSQL to `on_chunk`.
389    pub async fn copy_export_table_stream<F, Fut>(
390        &mut self,
391        table: &str,
392        columns: &[String],
393        on_chunk: F,
394    ) -> PgResult<()>
395    where
396        F: FnMut(Vec<u8>) -> Fut,
397        Fut: std::future::Future<Output = PgResult<()>>,
398    {
399        let quote_ident = |ident: &str| -> String {
400            format!("\"{}\"", ident.replace('\0', "").replace('"', "\"\""))
401        };
402        let cols: Vec<String> = columns.iter().map(|c| quote_ident(c)).collect();
403        let sql = format!(
404            "COPY {} ({}) TO STDOUT",
405            quote_ident(table),
406            cols.join(", ")
407        );
408        self.connection.copy_out_raw_stream(&sql, on_chunk).await
409    }
410
411    /// Stream an AST-native `Qail::Export` command as raw COPY chunks.
412    pub async fn copy_export_cmd_stream<F, Fut>(&mut self, cmd: &Qail, on_chunk: F) -> PgResult<()>
413    where
414        F: FnMut(Vec<u8>) -> Fut,
415        Fut: std::future::Future<Output = PgResult<()>>,
416    {
417        self.connection.copy_export_stream_raw(cmd, on_chunk).await
418    }
419
420    /// Stream an AST-native `Qail::Export` command as parsed text rows.
421    pub async fn copy_export_cmd_stream_rows<F>(&mut self, cmd: &Qail, on_row: F) -> PgResult<()>
422    where
423        F: FnMut(Vec<String>) -> PgResult<()>,
424    {
425        self.connection.copy_export_stream_rows(cmd, on_row).await
426    }
427
428    /// Stream large result sets using PostgreSQL cursors.
429    /// This method uses DECLARE CURSOR internally to stream rows in batches,
430    /// avoiding loading the entire result set into memory.
431    /// # Example
432    /// ```ignore
433    /// let cmd = Qail::get("large_table");
434    /// let batches = driver.stream_cmd(&cmd, 100).await?;
435    /// for batch in batches {
436    ///     for row in batch {
437    ///         // process row
438    ///     }
439    /// }
440    /// ```
441    pub async fn stream_cmd(&mut self, cmd: &Qail, batch_size: usize) -> PgResult<Vec<Vec<PgRow>>> {
442        use std::sync::atomic::{AtomicU64, Ordering};
443        static CURSOR_ID: AtomicU64 = AtomicU64::new(0);
444
445        let cursor_name = format!("qail_cursor_{}", CURSOR_ID.fetch_add(1, Ordering::SeqCst));
446
447        // AST-NATIVE: Generate SQL directly from AST (no to_sql_parameterized!)
448        use crate::protocol::AstEncoder;
449        let mut sql_buf = bytes::BytesMut::with_capacity(256);
450        let mut params: Vec<Option<Vec<u8>>> = Vec::new();
451        AstEncoder::encode_select_sql(cmd, &mut sql_buf, &mut params)
452            .map_err(|e| PgError::Encode(e.to_string()))?;
453        let sql = String::from_utf8_lossy(&sql_buf).to_string();
454
455        // Must be in a transaction for cursors
456        self.connection.begin_transaction().await?;
457
458        // Declare cursor
459        // Declare cursor with bind params — Extended Query Protocol handles $1, $2 etc.
460        self.connection
461            .declare_cursor(&cursor_name, &sql, &params)
462            .await?;
463
464        // Fetch all batches
465        let mut all_batches = Vec::new();
466        while let Some(rows) = self
467            .connection
468            .fetch_cursor(&cursor_name, batch_size)
469            .await?
470        {
471            let pg_rows: Vec<PgRow> = rows
472                .into_iter()
473                .map(|cols| PgRow {
474                    columns: cols,
475                    column_info: None,
476                })
477                .collect();
478            all_batches.push(pg_rows);
479        }
480
481        self.connection.close_cursor(&cursor_name).await?;
482        self.connection.commit().await?;
483
484        Ok(all_batches)
485    }
486}