Skip to main content

pg_wired/
pipeline.rs

//! Pipelined query execution: fuse `Parse`/`Bind`/`Execute`/`Sync` for
//! multiple statements into one `write()` and one `flush()`. Reuses prepared
//! statements via a bounded statement cache that is cleared wholesale when full.
4
5use std::collections::HashMap;
6
7use bytes::BytesMut;
8
9use crate::connection::WireConn;
10use crate::error::PgWireError;
11use crate::protocol::frontend;
12use crate::protocol::types::{BackendMsg, FormatCode, FrontendMsg, RawRow};
13
14/// High-level pipelined PostgreSQL client.
15/// Coalesces Parse+Bind+Execute+Sync into a single TCP write.
16/// Caches prepared statements to skip Parse on subsequent calls.
17pub struct PgPipeline {
18    conn: WireConn,
19    stmt_cache: HashMap<String, String>, // sql → statement name
20    stmt_counter: u64,
21    max_cache_size: usize,
22    send_buf: BytesMut,
23}
24
25impl std::fmt::Debug for PgPipeline {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        f.debug_struct("PgPipeline")
28            .field("conn", &self.conn)
29            .field("cached_statements", &self.stmt_cache.len())
30            .field("max_cache_size", &self.max_cache_size)
31            .finish()
32    }
33}
34
35impl PgPipeline {
36    /// Wrap a connected [`WireConn`] in a pipelining helper.
37    /// Statement cache size defaults to 256 entries.
38    pub fn new(conn: WireConn) -> Self {
39        Self {
40            conn,
41            stmt_cache: HashMap::new(),
42            stmt_counter: 0,
43            max_cache_size: 256,
44            send_buf: BytesMut::with_capacity(4096),
45        }
46    }
47
48    /// Execute a parameterized query, returning rows as `Vec<RawRow>`.
49    /// Uses binary format for parameters and results.
50    /// On cache miss: Parse+Bind+Execute+Sync in ONE write.
51    /// On cache hit: Bind+Execute+Sync in ONE write.
52    pub async fn query(
53        &mut self,
54        sql: &str,
55        params: &[Option<&[u8]>],
56        param_oids: &[u32],
57    ) -> Result<Vec<RawRow>, PgWireError> {
58        let (stmt_name, needs_parse) = self.lookup_or_alloc(sql);
59        let stmt_name_bytes = stmt_name.as_bytes().to_vec();
60
61        // Build all messages into one buffer.
62        self.send_buf.clear();
63
64        // Send params as Text format — our values are text strings that PostgreSQL
65        // casts via ($1::text)::target_type in the SQL.
66        let text_fmts: Vec<FormatCode> = vec![FormatCode::Text; params.len().max(1)];
67        let result_fmts = [FormatCode::Text]; // Text for JSON result
68
69        let mut msgs: Vec<FrontendMsg<'_>> = Vec::with_capacity(4);
70
71        if needs_parse {
72            msgs.push(FrontendMsg::Parse {
73                name: &stmt_name_bytes,
74                sql: sql.as_bytes(),
75                param_oids,
76            });
77        }
78
79        msgs.push(FrontendMsg::Bind {
80            portal: b"",
81            statement: &stmt_name_bytes,
82            param_formats: &text_fmts[..params.len()],
83            params,
84            result_formats: &result_fmts,
85        });
86
87        msgs.push(FrontendMsg::Execute {
88            portal: b"",
89            max_rows: 0, // unlimited
90        });
91
92        msgs.push(FrontendMsg::Sync);
93
94        // Encode all messages into ONE buffer.
95        frontend::encode_messages(&msgs, &mut self.send_buf);
96
97        // ONE write() syscall.
98        self.conn.send_raw(&self.send_buf).await?;
99
100        // Collect all rows until ReadyForQuery.
101        let (rows, _tag) = self.conn.collect_rows().await?;
102        Ok(rows)
103    }
104
105    /// Execute a simple query (no parameters, text protocol).
106    /// Used for SET LOCAL ROLE, set_config, BEGIN, COMMIT etc.
107    ///
108    /// Drains all backend messages until `ReadyForQuery`. If the server sent
109    /// an `ErrorResponse`, the first such error is captured and returned as
110    /// `Err(PgWireError::Pg(_))` once the resync completes.
111    pub async fn simple_query(&mut self, sql: &str) -> Result<(), PgWireError> {
112        self.send_buf.clear();
113        frontend::encode_message(&FrontendMsg::Query(sql.as_bytes()), &mut self.send_buf);
114        self.conn.send_raw(&self.send_buf).await?;
115
116        let mut first_error = None;
117        loop {
118            match self.conn.recv_msg().await? {
119                BackendMsg::ReadyForQuery { .. } => break,
120                BackendMsg::ErrorResponse { fields } if first_error.is_none() => {
121                    first_error = Some(fields);
122                }
123                _ => {}
124            }
125        }
126        if let Some(err) = first_error {
127            return Err(PgWireError::Pg(err));
128        }
129        Ok(())
130    }
131
132    /// Execute a simple query and return rows (text format).
133    /// Used for introspection queries (e.g., migration status).
134    pub async fn simple_query_rows(
135        &mut self,
136        sql: &str,
137    ) -> Result<(Vec<RawRow>, String), PgWireError> {
138        self.send_buf.clear();
139        frontend::encode_message(&FrontendMsg::Query(sql.as_bytes()), &mut self.send_buf);
140        self.conn.send_raw(&self.send_buf).await?;
141        self.conn.collect_rows().await
142    }
143
144    /// Execute a pipelined transaction: setup (simple) + query (parameterized) in TWO messages
145    /// but coalesced into ONE TCP write.
146    ///
147    /// This is the key optimization: BEGIN + SET LOCAL ROLE + set_config + parameterized query
148    /// all go in one write() syscall, with the data query using the safe binary protocol.
149    pub async fn pipeline_with_setup(
150        &mut self,
151        setup_sql: &str,
152        query_sql: &str,
153        params: &[Option<&[u8]>],
154        param_oids: &[u32],
155    ) -> Result<Vec<RawRow>, PgWireError> {
156        let (stmt_name, needs_parse) = self.lookup_or_alloc(query_sql);
157        let stmt_name_bytes = stmt_name.as_bytes().to_vec();
158
159        self.send_buf.clear();
160
161        // 1. Simple query for setup (BEGIN + SET ROLE + set_config).
162        frontend::encode_message(
163            &FrontendMsg::Query(setup_sql.as_bytes()),
164            &mut self.send_buf,
165        );
166
167        // 2. Extended query for data (Parse? + Bind + Execute + Sync).
168        // Send params as Text format — our values are text strings that PostgreSQL
169        // casts via ($1::text)::target_type in the SQL.
170        let text_fmts: Vec<FormatCode> = vec![FormatCode::Text; params.len().max(1)];
171        let result_fmts = [FormatCode::Text];
172
173        if needs_parse {
174            frontend::encode_message(
175                &FrontendMsg::Parse {
176                    name: &stmt_name_bytes,
177                    sql: query_sql.as_bytes(),
178                    param_oids,
179                },
180                &mut self.send_buf,
181            );
182        }
183
184        frontend::encode_message(
185            &FrontendMsg::Bind {
186                portal: b"",
187                statement: &stmt_name_bytes,
188                param_formats: &text_fmts[..params.len()],
189                params,
190                result_formats: &result_fmts,
191            },
192            &mut self.send_buf,
193        );
194
195        frontend::encode_message(
196            &FrontendMsg::Execute {
197                portal: b"",
198                max_rows: 0,
199            },
200            &mut self.send_buf,
201        );
202
203        frontend::encode_message(&FrontendMsg::Sync, &mut self.send_buf);
204
205        // ONE write() syscall for everything.
206        self.conn.send_raw(&self.send_buf).await?;
207
208        // Read responses: first ReadyForQuery from the simple query setup,
209        // then DataRows + ReadyForQuery from the extended query.
210        self.conn.drain_until_ready().await?; // Setup response
211        let (rows, _tag) = self.conn.collect_rows().await?; // Data response
212
213        Ok(rows)
214    }
215
216    /// Execute a pipelined transaction with COMMIT at the end.
217    /// setup_sql should contain "BEGIN; SET LOCAL ROLE ...; SELECT set_config(...)"
218    /// The commit is sent as a separate simple query, coalesced in the same write.
219    pub async fn pipeline_transaction(
220        &mut self,
221        setup_sql: &str,
222        query_sql: &str,
223        params: &[Option<&[u8]>],
224        param_oids: &[u32],
225    ) -> Result<Vec<RawRow>, PgWireError> {
226        let (stmt_name, needs_parse) = self.lookup_or_alloc(query_sql);
227        let stmt_name_bytes = stmt_name.as_bytes().to_vec();
228
229        self.send_buf.clear();
230
231        // 1. Simple query: BEGIN + SET ROLE + set_config
232        frontend::encode_message(
233            &FrontendMsg::Query(setup_sql.as_bytes()),
234            &mut self.send_buf,
235        );
236
237        // 2. Extended query: Bind + Execute + Sync (parameterized, binary safe)
238        // Send params as Text format — our values are text strings that PostgreSQL
239        // casts via ($1::text)::target_type in the SQL.
240        let text_fmts: Vec<FormatCode> = vec![FormatCode::Text; params.len().max(1)];
241        let result_fmts = [FormatCode::Text];
242
243        if needs_parse {
244            frontend::encode_message(
245                &FrontendMsg::Parse {
246                    name: &stmt_name_bytes,
247                    sql: query_sql.as_bytes(),
248                    param_oids,
249                },
250                &mut self.send_buf,
251            );
252        }
253
254        frontend::encode_message(
255            &FrontendMsg::Bind {
256                portal: b"",
257                statement: &stmt_name_bytes,
258                param_formats: &text_fmts[..params.len()],
259                params,
260                result_formats: &result_fmts,
261            },
262            &mut self.send_buf,
263        );
264
265        frontend::encode_message(
266            &FrontendMsg::Execute {
267                portal: b"",
268                max_rows: 0,
269            },
270            &mut self.send_buf,
271        );
272
273        frontend::encode_message(&FrontendMsg::Sync, &mut self.send_buf);
274
275        // 3. Simple query: COMMIT
276        frontend::encode_message(&FrontendMsg::Query(b"COMMIT"), &mut self.send_buf);
277
278        // ONE write() syscall for the entire transaction.
279        self.conn.send_raw(&self.send_buf).await?;
280
281        // Read responses in order:
282        // 1. ReadyForQuery from setup
283        // 2. DataRows + ReadyForQuery from data query
284        // 3. ReadyForQuery from COMMIT
285        self.conn.drain_until_ready().await?; // Setup
286        let (rows, _tag) = self.conn.collect_rows().await?; // Data
287        self.conn.drain_until_ready().await?; // COMMIT
288
289        Ok(rows)
290    }
291
292    /// Look up or allocate a statement name.
293    fn lookup_or_alloc(&mut self, sql: &str) -> (String, bool) {
294        if let Some(name) = self.stmt_cache.get(sql) {
295            return (name.clone(), false);
296        }
297
298        // Evict if cache is full.
299        if self.stmt_cache.len() >= self.max_cache_size {
300            // Simple eviction: clear all (LRU would be better).
301            self.stmt_cache.clear();
302        }
303
304        let name = format!("s{}", self.stmt_counter);
305        self.stmt_counter += 1;
306        self.stmt_cache.insert(sql.to_string(), name.clone());
307        (name, true)
308    }
309
310    /// Clear the statement cache (e.g., after DISCARD ALL).
311    pub fn clear_cache(&mut self) {
312        self.stmt_cache.clear();
313    }
314
315    /// Get a mutable reference to the underlying connection.
316    pub fn conn(&mut self) -> &mut WireConn {
317        &mut self.conn
318    }
319
320    /// Get a reference to the underlying connection.
321    pub fn conn_ref(&self) -> &WireConn {
322        &self.conn
323    }
324}