qail_pg/driver/
pipeline.rs1use super::{PgConnection, PgError, PgResult};
14use crate::protocol::{AstEncoder, BackendMessage, PgEncoder};
15use bytes::BytesMut;
16use tokio::io::AsyncWriteExt;
17
18impl PgConnection {
19 pub async fn query_pipeline(
21 &mut self,
22 queries: &[(&str, &[Option<Vec<u8>>])],
23 ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
24 let mut buf = BytesMut::new();
26 for (sql, params) in queries {
27 buf.extend_from_slice(&PgEncoder::encode_extended_query(sql, params)
28 .map_err(|e| PgError::Encode(e.to_string()))?);
29 }
30
31 self.stream.write_all(&buf).await?;
33
34 let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> = Vec::with_capacity(queries.len());
36 let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
37 let mut queries_completed = 0;
38
39 loop {
40 let msg = self.recv().await?;
41 match msg {
42 BackendMessage::ParseComplete | BackendMessage::BindComplete => {}
43 BackendMessage::RowDescription(_) => {}
44 BackendMessage::DataRow(data) => {
45 current_rows.push(data);
46 }
47 BackendMessage::CommandComplete(_) => {
48 all_results.push(std::mem::take(&mut current_rows));
49 queries_completed += 1;
50 }
51 BackendMessage::NoData => {
52 all_results.push(Vec::new());
53 queries_completed += 1;
54 }
55 BackendMessage::ReadyForQuery(_) => {
56 if queries_completed == queries.len() {
57 return Ok(all_results);
58 }
59 }
60 BackendMessage::ErrorResponse(err) => {
61 return Err(PgError::Query(err.message));
62 }
63 _ => {}
64 }
65 }
66 }
67
68 pub async fn pipeline_ast(
70 &mut self,
71 cmds: &[qail_core::ast::Qail],
72 ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
73 let buf = AstEncoder::encode_batch(cmds);
74 self.stream.write_all(&buf).await?;
75
76 let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> = Vec::with_capacity(cmds.len());
77 let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
78 let mut queries_completed = 0;
79
80 loop {
81 let msg = self.recv().await?;
82 match msg {
83 BackendMessage::ParseComplete | BackendMessage::BindComplete => {}
84 BackendMessage::RowDescription(_) => {}
85 BackendMessage::DataRow(data) => {
86 current_rows.push(data);
87 }
88 BackendMessage::CommandComplete(_) => {
89 all_results.push(std::mem::take(&mut current_rows));
90 queries_completed += 1;
91 }
92 BackendMessage::NoData => {
93 all_results.push(Vec::new());
94 queries_completed += 1;
95 }
96 BackendMessage::ReadyForQuery(_) => {
97 if queries_completed == cmds.len() {
98 return Ok(all_results);
99 }
100 }
101 BackendMessage::ErrorResponse(err) => {
102 return Err(PgError::Query(err.message));
103 }
104 _ => {}
105 }
106 }
107 }
108
109 pub async fn pipeline_ast_fast(&mut self, cmds: &[qail_core::ast::Qail]) -> PgResult<usize> {
111 let buf = AstEncoder::encode_batch(cmds);
112
113 self.stream.write_all(&buf).await?;
114 self.stream.flush().await?;
115
116 let mut queries_completed = 0;
117
118 loop {
119 let msg_type = self.recv_msg_type_fast().await?;
120 match msg_type {
121 b'C' | b'n' => queries_completed += 1,
122 b'Z' => {
123 if queries_completed == cmds.len() {
124 return Ok(queries_completed);
125 }
126 }
127 _ => {}
128 }
129 }
130 }
131
132 #[inline]
134 pub async fn pipeline_bytes_fast(
135 &mut self,
136 wire_bytes: &[u8],
137 expected_queries: usize,
138 ) -> PgResult<usize> {
139 self.stream.write_all(wire_bytes).await?;
140 self.stream.flush().await?;
141
142 let mut queries_completed = 0;
143
144 loop {
145 let msg_type = self.recv_msg_type_fast().await?;
146 match msg_type {
147 b'C' | b'n' => queries_completed += 1,
148 b'Z' => {
149 if queries_completed == expected_queries {
150 return Ok(queries_completed);
151 }
152 }
153 _ => {}
154 }
155 }
156 }
157
158 #[inline]
160 pub async fn pipeline_simple_fast(
161 &mut self,
162 cmds: &[qail_core::ast::Qail],
163 ) -> PgResult<usize> {
164 let buf = AstEncoder::encode_batch_simple(cmds);
165
166 self.stream.write_all(&buf).await?;
167 self.stream.flush().await?;
168
169 let mut queries_completed = 0;
170
171 loop {
172 let msg_type = self.recv_msg_type_fast().await?;
173 match msg_type {
174 b'C' => queries_completed += 1,
175 b'Z' => {
176 if queries_completed == cmds.len() {
177 return Ok(queries_completed);
178 }
179 }
180 _ => {}
181 }
182 }
183 }
184
185 #[inline]
187 pub async fn pipeline_simple_bytes_fast(
188 &mut self,
189 wire_bytes: &[u8],
190 expected_queries: usize,
191 ) -> PgResult<usize> {
192 self.stream.write_all(wire_bytes).await?;
193 self.stream.flush().await?;
194
195 let mut queries_completed = 0;
196
197 loop {
198 let msg_type = self.recv_msg_type_fast().await?;
199 match msg_type {
200 b'C' => queries_completed += 1,
201 b'Z' => {
202 if queries_completed == expected_queries {
203 return Ok(queries_completed);
204 }
205 }
206 _ => {}
207 }
208 }
209 }
210
211 #[inline]
216 pub async fn pipeline_ast_cached(
217 &mut self,
218 cmds: &[qail_core::ast::Qail],
219 ) -> PgResult<usize> {
220 if cmds.is_empty() {
221 return Ok(0);
222 }
223
224 let mut buf = BytesMut::with_capacity(cmds.len() * 64);
225
226 for cmd in cmds {
227 let (sql, params) = AstEncoder::encode_cmd_sql(cmd);
228 let stmt_name = Self::sql_to_stmt_name(&sql);
229
230 if !self.prepared_statements.contains_key(&stmt_name) {
231 buf.extend(PgEncoder::encode_parse(&stmt_name, &sql, &[]));
232 self.prepared_statements.insert(stmt_name.clone(), sql);
233 }
234
235 buf.extend_from_slice(&PgEncoder::encode_bind("", &stmt_name, ¶ms)
236 .map_err(|e| PgError::Encode(e.to_string()))?);
237 buf.extend(PgEncoder::encode_execute("", 0));
238 }
239
240 buf.extend(PgEncoder::encode_sync());
241
242 self.stream.write_all(&buf).await?;
243 self.stream.flush().await?;
244
245 let mut queries_completed = 0;
246
247 loop {
248 let msg_type = self.recv_msg_type_fast().await?;
249 match msg_type {
250 b'C' | b'n' => queries_completed += 1,
251 b'Z' => {
252 if queries_completed == cmds.len() {
253 return Ok(queries_completed);
254 }
255 }
256 _ => {}
257 }
258 }
259 }
260
261 #[inline]
276 pub async fn pipeline_prepared_fast(
277 &mut self,
278 stmt: &super::PreparedStatement,
279 params_batch: &[Vec<Option<Vec<u8>>>],
280 ) -> PgResult<usize> {
281 if params_batch.is_empty() {
282 return Ok(0);
283 }
284
285 let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
287
288 let is_new = !self.prepared_statements.contains_key(&stmt.name);
289
290 if is_new {
291 return Err(PgError::Query(
292 "Statement not prepared. Call prepare() first.".to_string(),
293 ));
294 }
295
296 for params in params_batch {
298 PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
299 .map_err(|e| PgError::Encode(e.to_string()))?;
300 PgEncoder::encode_execute_to(&mut buf);
301 }
302
303 PgEncoder::encode_sync_to(&mut buf);
304
305 self.stream.write_all(&buf).await?;
306 self.stream.flush().await?;
307
308 let mut queries_completed = 0;
309
310 loop {
311 let msg_type = self.recv_msg_type_fast().await?;
312 match msg_type {
313 b'C' | b'n' => queries_completed += 1,
314 b'Z' => {
315 if queries_completed == params_batch.len() {
316 return Ok(queries_completed);
317 }
318 }
319 _ => {}
320 }
321 }
322 }
323
324 pub async fn prepare(&mut self, sql: &str) -> PgResult<super::PreparedStatement> {
327 use super::prepared::sql_bytes_to_stmt_name;
328
329 let stmt_name = sql_bytes_to_stmt_name(sql.as_bytes());
330
331 if !self.prepared_statements.contains_key(&stmt_name) {
332 let mut buf = BytesMut::with_capacity(sql.len() + 32);
333 buf.extend(PgEncoder::encode_parse(&stmt_name, sql, &[]));
334 buf.extend(PgEncoder::encode_sync());
335
336 self.stream.write_all(&buf).await?;
337 self.stream.flush().await?;
338
339 loop {
341 let msg_type = self.recv_msg_type_fast().await?;
342 match msg_type {
343 b'1' => {
344 self.prepared_statements
346 .insert(stmt_name.clone(), sql.to_string());
347 }
348 b'Z' => break, _ => {}
350 }
351 }
352 }
353
354 Ok(super::PreparedStatement {
355 name: stmt_name,
356 param_count: sql.matches('$').count(),
357 })
358 }
359
360 pub async fn pipeline_prepared_results(
362 &mut self,
363 stmt: &super::PreparedStatement,
364 params_batch: &[Vec<Option<Vec<u8>>>],
365 ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
366 if params_batch.is_empty() {
367 return Ok(Vec::new());
368 }
369
370 if !self.prepared_statements.contains_key(&stmt.name) {
371 return Err(PgError::Query(
372 "Statement not prepared. Call prepare() first.".to_string(),
373 ));
374 }
375
376 let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
377
378 for params in params_batch {
379 PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
380 .map_err(|e| PgError::Encode(e.to_string()))?;
381 PgEncoder::encode_execute_to(&mut buf);
382 }
383
384 PgEncoder::encode_sync_to(&mut buf);
385
386 self.stream.write_all(&buf).await?;
387 self.stream.flush().await?;
388
389 let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> =
391 Vec::with_capacity(params_batch.len());
392 let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
393
394 loop {
395 let (msg_type, data) = self.recv_with_data_fast().await?;
396
397 match msg_type {
398 b'2' => {} b'T' => {} b'D' => {
401 if let Some(row) = data {
403 current_rows.push(row);
404 }
405 }
406 b'C' => {
407 all_results.push(std::mem::take(&mut current_rows));
409 }
410 b'n' => {
411 all_results.push(Vec::new());
413 }
414 b'Z' => {
415 if all_results.len() == params_batch.len() {
417 return Ok(all_results);
418 }
419 }
420 _ => {}
421 }
422 }
423 }
424
425 pub async fn pipeline_prepared_zerocopy(
427 &mut self,
428 stmt: &super::PreparedStatement,
429 params_batch: &[Vec<Option<Vec<u8>>>],
430 ) -> PgResult<Vec<Vec<Vec<Option<bytes::Bytes>>>>> {
431 if params_batch.is_empty() {
432 return Ok(Vec::new());
433 }
434
435 if !self.prepared_statements.contains_key(&stmt.name) {
436 return Err(PgError::Query(
437 "Statement not prepared. Call prepare() first.".to_string(),
438 ));
439 }
440
441 let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
442
443 for params in params_batch {
444 PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
445 .map_err(|e| PgError::Encode(e.to_string()))?;
446 PgEncoder::encode_execute_to(&mut buf);
447 }
448
449 PgEncoder::encode_sync_to(&mut buf);
450
451 self.stream.write_all(&buf).await?;
452 self.stream.flush().await?;
453
454 let mut all_results: Vec<Vec<Vec<Option<bytes::Bytes>>>> =
456 Vec::with_capacity(params_batch.len());
457 let mut current_rows: Vec<Vec<Option<bytes::Bytes>>> = Vec::new();
458
459 loop {
460 let (msg_type, data) = self.recv_data_zerocopy().await?;
461
462 match msg_type {
463 b'2' => {} b'T' => {} b'D' => {
466 if let Some(row) = data {
468 current_rows.push(row);
469 }
470 }
471 b'C' => {
472 all_results.push(std::mem::take(&mut current_rows));
474 }
475 b'n' => {
476 all_results.push(Vec::new());
478 }
479 b'Z' => {
480 if all_results.len() == params_batch.len() {
482 return Ok(all_results);
483 }
484 }
485 _ => {}
486 }
487 }
488 }
489
490 pub async fn pipeline_prepared_ultra(
492 &mut self,
493 stmt: &super::PreparedStatement,
494 params_batch: &[Vec<Option<Vec<u8>>>],
495 ) -> PgResult<Vec<Vec<(bytes::Bytes, bytes::Bytes)>>> {
496 if params_batch.is_empty() {
497 return Ok(Vec::new());
498 }
499
500 if !self.prepared_statements.contains_key(&stmt.name) {
501 return Err(PgError::Query(
502 "Statement not prepared. Call prepare() first.".to_string(),
503 ));
504 }
505
506 let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
507
508 for params in params_batch {
509 PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
510 .map_err(|e| PgError::Encode(e.to_string()))?;
511 PgEncoder::encode_execute_to(&mut buf);
512 }
513
514 PgEncoder::encode_sync_to(&mut buf);
515
516 self.stream.write_all(&buf).await?;
517 self.stream.flush().await?;
518
519 let mut all_results: Vec<Vec<(bytes::Bytes, bytes::Bytes)>> =
521 Vec::with_capacity(params_batch.len());
522 let mut current_rows: Vec<(bytes::Bytes, bytes::Bytes)> = Vec::with_capacity(16);
523
524 loop {
525 let (msg_type, data) = self.recv_data_ultra().await?;
526
527 match msg_type {
528 b'2' | b'T' => {} b'D' => {
530 if let Some(row) = data {
531 current_rows.push(row);
532 }
533 }
534 b'C' => {
535 all_results.push(std::mem::take(&mut current_rows));
536 current_rows = Vec::with_capacity(16);
537 }
538 b'n' => {
539 all_results.push(Vec::new());
540 }
541 b'Z' => {
542 if all_results.len() == params_batch.len() {
543 return Ok(all_results);
544 }
545 }
546 _ => {}
547 }
548 }
549 }
550}