1use std::collections::HashMap;
6
7use bytes::BytesMut;
8
9use crate::connection::WireConn;
10use crate::error::PgWireError;
11use crate::protocol::frontend;
12use crate::protocol::types::{BackendMsg, FormatCode, FrontendMsg, RawRow};
13
14pub struct PgPipeline {
18 conn: WireConn,
19 stmt_cache: HashMap<String, String>, stmt_counter: u64,
21 max_cache_size: usize,
22 send_buf: BytesMut,
23}
24
25impl std::fmt::Debug for PgPipeline {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 f.debug_struct("PgPipeline")
28 .field("conn", &self.conn)
29 .field("cached_statements", &self.stmt_cache.len())
30 .field("max_cache_size", &self.max_cache_size)
31 .finish()
32 }
33}
34
35impl PgPipeline {
36 pub fn new(conn: WireConn) -> Self {
39 Self {
40 conn,
41 stmt_cache: HashMap::new(),
42 stmt_counter: 0,
43 max_cache_size: 256,
44 send_buf: BytesMut::with_capacity(4096),
45 }
46 }
47
48 pub async fn query(
53 &mut self,
54 sql: &str,
55 params: &[Option<&[u8]>],
56 param_oids: &[u32],
57 ) -> Result<Vec<RawRow>, PgWireError> {
58 let (stmt_name, needs_parse) = self.lookup_or_alloc(sql);
59 let stmt_name_bytes = stmt_name.as_bytes().to_vec();
60
61 self.send_buf.clear();
63
64 let text_fmts: Vec<FormatCode> = vec![FormatCode::Text; params.len().max(1)];
67 let result_fmts = [FormatCode::Text]; let mut msgs: Vec<FrontendMsg<'_>> = Vec::with_capacity(4);
70
71 if needs_parse {
72 msgs.push(FrontendMsg::Parse {
73 name: &stmt_name_bytes,
74 sql: sql.as_bytes(),
75 param_oids,
76 });
77 }
78
79 msgs.push(FrontendMsg::Bind {
80 portal: b"",
81 statement: &stmt_name_bytes,
82 param_formats: &text_fmts[..params.len()],
83 params,
84 result_formats: &result_fmts,
85 });
86
87 msgs.push(FrontendMsg::Execute {
88 portal: b"",
89 max_rows: 0, });
91
92 msgs.push(FrontendMsg::Sync);
93
94 frontend::encode_messages(&msgs, &mut self.send_buf);
96
97 self.conn.send_raw(&self.send_buf).await?;
99
100 let (rows, _tag) = self.conn.collect_rows().await?;
102 Ok(rows)
103 }
104
105 pub async fn simple_query(&mut self, sql: &str) -> Result<(), PgWireError> {
112 self.send_buf.clear();
113 frontend::encode_message(&FrontendMsg::Query(sql.as_bytes()), &mut self.send_buf);
114 self.conn.send_raw(&self.send_buf).await?;
115
116 let mut first_error = None;
117 loop {
118 match self.conn.recv_msg().await? {
119 BackendMsg::ReadyForQuery { .. } => break,
120 BackendMsg::ErrorResponse { fields } if first_error.is_none() => {
121 first_error = Some(fields);
122 }
123 _ => {}
124 }
125 }
126 if let Some(err) = first_error {
127 return Err(PgWireError::Pg(err));
128 }
129 Ok(())
130 }
131
132 pub async fn simple_query_rows(
135 &mut self,
136 sql: &str,
137 ) -> Result<(Vec<RawRow>, String), PgWireError> {
138 self.send_buf.clear();
139 frontend::encode_message(&FrontendMsg::Query(sql.as_bytes()), &mut self.send_buf);
140 self.conn.send_raw(&self.send_buf).await?;
141 self.conn.collect_rows().await
142 }
143
144 pub async fn pipeline_with_setup(
150 &mut self,
151 setup_sql: &str,
152 query_sql: &str,
153 params: &[Option<&[u8]>],
154 param_oids: &[u32],
155 ) -> Result<Vec<RawRow>, PgWireError> {
156 let (stmt_name, needs_parse) = self.lookup_or_alloc(query_sql);
157 let stmt_name_bytes = stmt_name.as_bytes().to_vec();
158
159 self.send_buf.clear();
160
161 frontend::encode_message(
163 &FrontendMsg::Query(setup_sql.as_bytes()),
164 &mut self.send_buf,
165 );
166
167 let text_fmts: Vec<FormatCode> = vec![FormatCode::Text; params.len().max(1)];
171 let result_fmts = [FormatCode::Text];
172
173 if needs_parse {
174 frontend::encode_message(
175 &FrontendMsg::Parse {
176 name: &stmt_name_bytes,
177 sql: query_sql.as_bytes(),
178 param_oids,
179 },
180 &mut self.send_buf,
181 );
182 }
183
184 frontend::encode_message(
185 &FrontendMsg::Bind {
186 portal: b"",
187 statement: &stmt_name_bytes,
188 param_formats: &text_fmts[..params.len()],
189 params,
190 result_formats: &result_fmts,
191 },
192 &mut self.send_buf,
193 );
194
195 frontend::encode_message(
196 &FrontendMsg::Execute {
197 portal: b"",
198 max_rows: 0,
199 },
200 &mut self.send_buf,
201 );
202
203 frontend::encode_message(&FrontendMsg::Sync, &mut self.send_buf);
204
205 self.conn.send_raw(&self.send_buf).await?;
207
208 self.conn.drain_until_ready().await?; let (rows, _tag) = self.conn.collect_rows().await?; Ok(rows)
214 }
215
216 pub async fn pipeline_transaction(
220 &mut self,
221 setup_sql: &str,
222 query_sql: &str,
223 params: &[Option<&[u8]>],
224 param_oids: &[u32],
225 ) -> Result<Vec<RawRow>, PgWireError> {
226 let (stmt_name, needs_parse) = self.lookup_or_alloc(query_sql);
227 let stmt_name_bytes = stmt_name.as_bytes().to_vec();
228
229 self.send_buf.clear();
230
231 frontend::encode_message(
233 &FrontendMsg::Query(setup_sql.as_bytes()),
234 &mut self.send_buf,
235 );
236
237 let text_fmts: Vec<FormatCode> = vec![FormatCode::Text; params.len().max(1)];
241 let result_fmts = [FormatCode::Text];
242
243 if needs_parse {
244 frontend::encode_message(
245 &FrontendMsg::Parse {
246 name: &stmt_name_bytes,
247 sql: query_sql.as_bytes(),
248 param_oids,
249 },
250 &mut self.send_buf,
251 );
252 }
253
254 frontend::encode_message(
255 &FrontendMsg::Bind {
256 portal: b"",
257 statement: &stmt_name_bytes,
258 param_formats: &text_fmts[..params.len()],
259 params,
260 result_formats: &result_fmts,
261 },
262 &mut self.send_buf,
263 );
264
265 frontend::encode_message(
266 &FrontendMsg::Execute {
267 portal: b"",
268 max_rows: 0,
269 },
270 &mut self.send_buf,
271 );
272
273 frontend::encode_message(&FrontendMsg::Sync, &mut self.send_buf);
274
275 frontend::encode_message(&FrontendMsg::Query(b"COMMIT"), &mut self.send_buf);
277
278 self.conn.send_raw(&self.send_buf).await?;
280
281 self.conn.drain_until_ready().await?; let (rows, _tag) = self.conn.collect_rows().await?; self.conn.drain_until_ready().await?; Ok(rows)
290 }
291
292 fn lookup_or_alloc(&mut self, sql: &str) -> (String, bool) {
294 if let Some(name) = self.stmt_cache.get(sql) {
295 return (name.clone(), false);
296 }
297
298 if self.stmt_cache.len() >= self.max_cache_size {
300 self.stmt_cache.clear();
302 }
303
304 let name = format!("s{}", self.stmt_counter);
305 self.stmt_counter += 1;
306 self.stmt_cache.insert(sql.to_string(), name.clone());
307 (name, true)
308 }
309
310 pub fn clear_cache(&mut self) {
312 self.stmt_cache.clear();
313 }
314
315 pub fn conn(&mut self) -> &mut WireConn {
317 &mut self.conn
318 }
319
320 pub fn conn_ref(&self) -> &WireConn {
322 &self.conn
323 }
324}