qail_pg/driver/
pipeline.rs1use super::{PgConnection, PgError, PgResult};
14use crate::protocol::{AstEncoder, BackendMessage, PgEncoder};
15use bytes::BytesMut;
16use tokio::io::AsyncWriteExt;
17
18impl PgConnection {
19 pub async fn query_pipeline(
21 &mut self,
22 queries: &[(&str, &[Option<Vec<u8>>])],
23 ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
24 let mut buf = BytesMut::new();
26 for (sql, params) in queries {
27 buf.extend_from_slice(
28 &PgEncoder::encode_extended_query(sql, params)
29 .map_err(|e| PgError::Encode(e.to_string()))?,
30 );
31 }
32
33 self.stream.write_all(&buf).await?;
35
36 let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> = Vec::with_capacity(queries.len());
38 let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
39 let mut queries_completed = 0;
40
41 loop {
42 let msg = self.recv().await?;
43 match msg {
44 BackendMessage::ParseComplete | BackendMessage::BindComplete => {}
45 BackendMessage::RowDescription(_) => {}
46 BackendMessage::DataRow(data) => {
47 current_rows.push(data);
48 }
49 BackendMessage::CommandComplete(_) => {
50 all_results.push(std::mem::take(&mut current_rows));
51 queries_completed += 1;
52 }
53 BackendMessage::NoData => {
54 all_results.push(Vec::new());
55 queries_completed += 1;
56 }
57 BackendMessage::ReadyForQuery(_) => {
58 if queries_completed == queries.len() {
59 return Ok(all_results);
60 }
61 }
62 BackendMessage::ErrorResponse(err) => {
63 return Err(PgError::QueryServer(err.into()));
64 }
65 _ => {}
66 }
67 }
68 }
69
70 pub async fn pipeline_ast(
72 &mut self,
73 cmds: &[qail_core::ast::Qail],
74 ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
75 let buf = AstEncoder::encode_batch(cmds).map_err(|e| PgError::Encode(e.to_string()))?;
76 self.stream.write_all(&buf).await?;
77
78 let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> = Vec::with_capacity(cmds.len());
79 let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
80 let mut queries_completed = 0;
81
82 loop {
83 let msg = self.recv().await?;
84 match msg {
85 BackendMessage::ParseComplete | BackendMessage::BindComplete => {}
86 BackendMessage::RowDescription(_) => {}
87 BackendMessage::DataRow(data) => {
88 current_rows.push(data);
89 }
90 BackendMessage::CommandComplete(_) => {
91 all_results.push(std::mem::take(&mut current_rows));
92 queries_completed += 1;
93 }
94 BackendMessage::NoData => {
95 all_results.push(Vec::new());
96 queries_completed += 1;
97 }
98 BackendMessage::ReadyForQuery(_) => {
99 if queries_completed == cmds.len() {
100 return Ok(all_results);
101 }
102 }
103 BackendMessage::ErrorResponse(err) => {
104 return Err(PgError::QueryServer(err.into()));
105 }
106 _ => {}
107 }
108 }
109 }
110
111 pub async fn pipeline_ast_fast(&mut self, cmds: &[qail_core::ast::Qail]) -> PgResult<usize> {
113 let buf = AstEncoder::encode_batch(cmds).map_err(|e| PgError::Encode(e.to_string()))?;
114
115 self.stream.write_all(&buf).await?;
116 self.stream.flush().await?;
117
118 let mut queries_completed = 0;
119
120 loop {
121 let msg_type = self.recv_msg_type_fast().await?;
122 match msg_type {
123 b'C' | b'n' => queries_completed += 1,
124 b'Z' => {
125 if queries_completed == cmds.len() {
126 return Ok(queries_completed);
127 }
128 }
129 _ => {}
130 }
131 }
132 }
133
134 #[inline]
136 pub async fn pipeline_bytes_fast(
137 &mut self,
138 wire_bytes: &[u8],
139 expected_queries: usize,
140 ) -> PgResult<usize> {
141 self.stream.write_all(wire_bytes).await?;
142 self.stream.flush().await?;
143
144 let mut queries_completed = 0;
145
146 loop {
147 let msg_type = self.recv_msg_type_fast().await?;
148 match msg_type {
149 b'C' | b'n' => queries_completed += 1,
150 b'Z' => {
151 if queries_completed == expected_queries {
152 return Ok(queries_completed);
153 }
154 }
155 _ => {}
156 }
157 }
158 }
159
160 #[inline]
162 pub async fn pipeline_simple_fast(&mut self, cmds: &[qail_core::ast::Qail]) -> PgResult<usize> {
163 let buf =
164 AstEncoder::encode_batch_simple(cmds).map_err(|e| PgError::Encode(e.to_string()))?;
165 self.stream.write_all(&buf).await?;
166 self.stream.flush().await?;
167
168 let mut queries_completed = 0;
169
170 loop {
171 let msg_type = self.recv_msg_type_fast().await?;
172 match msg_type {
173 b'C' => queries_completed += 1,
174 b'Z' => {
175 if queries_completed == cmds.len() {
176 return Ok(queries_completed);
177 }
178 }
179 _ => {}
180 }
181 }
182 }
183
184 #[inline]
186 pub async fn pipeline_simple_bytes_fast(
187 &mut self,
188 wire_bytes: &[u8],
189 expected_queries: usize,
190 ) -> PgResult<usize> {
191 self.stream.write_all(wire_bytes).await?;
192 self.stream.flush().await?;
193
194 let mut queries_completed = 0;
195
196 loop {
197 let msg_type = self.recv_msg_type_fast().await?;
198 match msg_type {
199 b'C' => queries_completed += 1,
200 b'Z' => {
201 if queries_completed == expected_queries {
202 return Ok(queries_completed);
203 }
204 }
205 _ => {}
206 }
207 }
208 }
209
210 #[inline]
215 pub async fn pipeline_ast_cached(&mut self, cmds: &[qail_core::ast::Qail]) -> PgResult<usize> {
216 if cmds.is_empty() {
217 return Ok(0);
218 }
219
220 let mut buf = BytesMut::with_capacity(cmds.len() * 64);
221
222 for cmd in cmds {
223 let (sql, params) =
224 AstEncoder::encode_cmd_sql(cmd).map_err(|e| PgError::Encode(e.to_string()))?;
225 let stmt_name = Self::sql_to_stmt_name(&sql);
226
227 if !self.prepared_statements.contains_key(&stmt_name) {
228 self.evict_prepared_if_full();
229 buf.extend(PgEncoder::encode_parse(&stmt_name, &sql, &[]));
230 self.prepared_statements.insert(stmt_name.clone(), sql);
231 }
232
233 buf.extend_from_slice(
234 &PgEncoder::encode_bind("", &stmt_name, ¶ms)
235 .map_err(|e| PgError::Encode(e.to_string()))?,
236 );
237 buf.extend(PgEncoder::encode_execute("", 0));
238 }
239
240 buf.extend(PgEncoder::encode_sync());
241
242 self.stream.write_all(&buf).await?;
243 self.stream.flush().await?;
244
245 let mut queries_completed = 0;
246
247 loop {
248 let msg_type = self.recv_msg_type_fast().await?;
249 match msg_type {
250 b'C' | b'n' => queries_completed += 1,
251 b'Z' => {
252 if queries_completed == cmds.len() {
253 return Ok(queries_completed);
254 }
255 }
256 _ => {}
257 }
258 }
259 }
260
261 #[inline]
276 pub async fn pipeline_prepared_fast(
277 &mut self,
278 stmt: &super::PreparedStatement,
279 params_batch: &[Vec<Option<Vec<u8>>>],
280 ) -> PgResult<usize> {
281 if params_batch.is_empty() {
282 return Ok(0);
283 }
284
285 let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
287
288 let is_new = !self.prepared_statements.contains_key(&stmt.name);
289
290 if is_new {
291 return Err(PgError::Query(
292 "Statement not prepared. Call prepare() first.".to_string(),
293 ));
294 }
295
296 for params in params_batch {
298 PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
299 .map_err(|e| PgError::Encode(e.to_string()))?;
300 PgEncoder::encode_execute_to(&mut buf);
301 }
302
303 PgEncoder::encode_sync_to(&mut buf);
304
305 self.stream.write_all(&buf).await?;
306 self.stream.flush().await?;
307
308 let mut queries_completed = 0;
309
310 loop {
311 let msg_type = self.recv_msg_type_fast().await?;
312 match msg_type {
313 b'C' | b'n' => queries_completed += 1,
314 b'Z' => {
315 if queries_completed == params_batch.len() {
316 return Ok(queries_completed);
317 }
318 }
319 _ => {}
320 }
321 }
322 }
323
324 pub async fn prepare(&mut self, sql: &str) -> PgResult<super::PreparedStatement> {
327 use super::prepared::sql_bytes_to_stmt_name;
328
329 let stmt_name = sql_bytes_to_stmt_name(sql.as_bytes());
330
331 if !self.prepared_statements.contains_key(&stmt_name) {
332 self.evict_prepared_if_full();
333 let mut buf = BytesMut::with_capacity(sql.len() + 32);
334 buf.extend(PgEncoder::encode_parse(&stmt_name, sql, &[]));
335 buf.extend(PgEncoder::encode_sync());
336
337 self.stream.write_all(&buf).await?;
338 self.stream.flush().await?;
339
340 loop {
342 let msg_type = self.recv_msg_type_fast().await?;
343 match msg_type {
344 b'1' => {
345 self.prepared_statements
347 .insert(stmt_name.clone(), sql.to_string());
348 }
349 b'Z' => break, _ => {}
351 }
352 }
353 }
354
355 Ok(super::PreparedStatement {
356 name: stmt_name,
357 param_count: sql.matches('$').count(),
358 })
359 }
360
361 pub async fn pipeline_prepared_results(
363 &mut self,
364 stmt: &super::PreparedStatement,
365 params_batch: &[Vec<Option<Vec<u8>>>],
366 ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
367 if params_batch.is_empty() {
368 return Ok(Vec::new());
369 }
370
371 if !self.prepared_statements.contains_key(&stmt.name) {
372 return Err(PgError::Query(
373 "Statement not prepared. Call prepare() first.".to_string(),
374 ));
375 }
376
377 let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
378
379 for params in params_batch {
380 PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
381 .map_err(|e| PgError::Encode(e.to_string()))?;
382 PgEncoder::encode_execute_to(&mut buf);
383 }
384
385 PgEncoder::encode_sync_to(&mut buf);
386
387 self.stream.write_all(&buf).await?;
388 self.stream.flush().await?;
389
390 let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> =
392 Vec::with_capacity(params_batch.len());
393 let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
394
395 loop {
396 let (msg_type, data) = self.recv_with_data_fast().await?;
397
398 match msg_type {
399 b'2' => {} b'T' => {} b'D' => {
402 if let Some(row) = data {
404 current_rows.push(row);
405 }
406 }
407 b'C' => {
408 all_results.push(std::mem::take(&mut current_rows));
410 }
411 b'n' => {
412 all_results.push(Vec::new());
414 }
415 b'Z' => {
416 if all_results.len() == params_batch.len() {
418 return Ok(all_results);
419 }
420 }
421 _ => {}
422 }
423 }
424 }
425
426 pub async fn pipeline_prepared_zerocopy(
428 &mut self,
429 stmt: &super::PreparedStatement,
430 params_batch: &[Vec<Option<Vec<u8>>>],
431 ) -> PgResult<Vec<Vec<Vec<Option<bytes::Bytes>>>>> {
432 if params_batch.is_empty() {
433 return Ok(Vec::new());
434 }
435
436 if !self.prepared_statements.contains_key(&stmt.name) {
437 return Err(PgError::Query(
438 "Statement not prepared. Call prepare() first.".to_string(),
439 ));
440 }
441
442 let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
443
444 for params in params_batch {
445 PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
446 .map_err(|e| PgError::Encode(e.to_string()))?;
447 PgEncoder::encode_execute_to(&mut buf);
448 }
449
450 PgEncoder::encode_sync_to(&mut buf);
451
452 self.stream.write_all(&buf).await?;
453 self.stream.flush().await?;
454
455 let mut all_results: Vec<Vec<Vec<Option<bytes::Bytes>>>> =
457 Vec::with_capacity(params_batch.len());
458 let mut current_rows: Vec<Vec<Option<bytes::Bytes>>> = Vec::new();
459
460 loop {
461 let (msg_type, data) = self.recv_data_zerocopy().await?;
462
463 match msg_type {
464 b'2' => {} b'T' => {} b'D' => {
467 if let Some(row) = data {
469 current_rows.push(row);
470 }
471 }
472 b'C' => {
473 all_results.push(std::mem::take(&mut current_rows));
475 }
476 b'n' => {
477 all_results.push(Vec::new());
479 }
480 b'Z' => {
481 if all_results.len() == params_batch.len() {
483 return Ok(all_results);
484 }
485 }
486 _ => {}
487 }
488 }
489 }
490
491 pub async fn pipeline_prepared_ultra(
493 &mut self,
494 stmt: &super::PreparedStatement,
495 params_batch: &[Vec<Option<Vec<u8>>>],
496 ) -> PgResult<Vec<Vec<(bytes::Bytes, bytes::Bytes)>>> {
497 if params_batch.is_empty() {
498 return Ok(Vec::new());
499 }
500
501 if !self.prepared_statements.contains_key(&stmt.name) {
502 return Err(PgError::Query(
503 "Statement not prepared. Call prepare() first.".to_string(),
504 ));
505 }
506
507 let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
508
509 for params in params_batch {
510 PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
511 .map_err(|e| PgError::Encode(e.to_string()))?;
512 PgEncoder::encode_execute_to(&mut buf);
513 }
514
515 PgEncoder::encode_sync_to(&mut buf);
516
517 self.stream.write_all(&buf).await?;
518 self.stream.flush().await?;
519
520 let mut all_results: Vec<Vec<(bytes::Bytes, bytes::Bytes)>> =
522 Vec::with_capacity(params_batch.len());
523 let mut current_rows: Vec<(bytes::Bytes, bytes::Bytes)> = Vec::with_capacity(16);
524
525 loop {
526 let (msg_type, data) = self.recv_data_ultra().await?;
527
528 match msg_type {
529 b'2' | b'T' => {} b'D' => {
531 if let Some(row) = data {
532 current_rows.push(row);
533 }
534 }
535 b'C' => {
536 all_results.push(std::mem::take(&mut current_rows));
537 current_rows = Vec::with_capacity(16);
538 }
539 b'n' => {
540 all_results.push(Vec::new());
541 }
542 b'Z' => {
543 if all_results.len() == params_batch.len() {
544 return Ok(all_results);
545 }
546 }
547 _ => {}
548 }
549 }
550 }
551}