qail_pg/driver/
pipeline.rs1use super::{PgConnection, PgError, PgResult};
14use crate::protocol::{AstEncoder, BackendMessage, PgEncoder};
15use bytes::BytesMut;
16use tokio::io::AsyncWriteExt;
17
18impl PgConnection {
19 pub async fn query_pipeline(
21 &mut self,
22 queries: &[(&str, &[Option<Vec<u8>>])],
23 ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
24 let mut buf = BytesMut::new();
26 for (sql, params) in queries {
27 buf.extend_from_slice(&PgEncoder::encode_extended_query(sql, params)
28 .map_err(|e| PgError::Encode(e.to_string()))?);
29 }
30
31 self.stream.write_all(&buf).await?;
33
34 let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> = Vec::with_capacity(queries.len());
36 let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
37 let mut queries_completed = 0;
38
39 loop {
40 let msg = self.recv().await?;
41 match msg {
42 BackendMessage::ParseComplete | BackendMessage::BindComplete => {}
43 BackendMessage::RowDescription(_) => {}
44 BackendMessage::DataRow(data) => {
45 current_rows.push(data);
46 }
47 BackendMessage::CommandComplete(_) => {
48 all_results.push(std::mem::take(&mut current_rows));
49 queries_completed += 1;
50 }
51 BackendMessage::NoData => {
52 all_results.push(Vec::new());
53 queries_completed += 1;
54 }
55 BackendMessage::ReadyForQuery(_) => {
56 if queries_completed == queries.len() {
57 return Ok(all_results);
58 }
59 }
60 BackendMessage::ErrorResponse(err) => {
61 return Err(PgError::Query(err.message));
62 }
63 _ => {}
64 }
65 }
66 }
67
68 pub async fn pipeline_ast(
70 &mut self,
71 cmds: &[qail_core::ast::Qail],
72 ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
73 let buf = AstEncoder::encode_batch(cmds).map_err(|e| PgError::Encode(e.to_string()))?;
74 self.stream.write_all(&buf).await?;
75
76 let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> = Vec::with_capacity(cmds.len());
77 let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
78 let mut queries_completed = 0;
79
80 loop {
81 let msg = self.recv().await?;
82 match msg {
83 BackendMessage::ParseComplete | BackendMessage::BindComplete => {}
84 BackendMessage::RowDescription(_) => {}
85 BackendMessage::DataRow(data) => {
86 current_rows.push(data);
87 }
88 BackendMessage::CommandComplete(_) => {
89 all_results.push(std::mem::take(&mut current_rows));
90 queries_completed += 1;
91 }
92 BackendMessage::NoData => {
93 all_results.push(Vec::new());
94 queries_completed += 1;
95 }
96 BackendMessage::ReadyForQuery(_) => {
97 if queries_completed == cmds.len() {
98 return Ok(all_results);
99 }
100 }
101 BackendMessage::ErrorResponse(err) => {
102 return Err(PgError::Query(err.message));
103 }
104 _ => {}
105 }
106 }
107 }
108
109 pub async fn pipeline_ast_fast(&mut self, cmds: &[qail_core::ast::Qail]) -> PgResult<usize> {
111 let buf = AstEncoder::encode_batch(cmds).map_err(|e| PgError::Encode(e.to_string()))?;
112
113 self.stream.write_all(&buf).await?;
114 self.stream.flush().await?;
115
116 let mut queries_completed = 0;
117
118 loop {
119 let msg_type = self.recv_msg_type_fast().await?;
120 match msg_type {
121 b'C' | b'n' => queries_completed += 1,
122 b'Z' => {
123 if queries_completed == cmds.len() {
124 return Ok(queries_completed);
125 }
126 }
127 _ => {}
128 }
129 }
130 }
131
132 #[inline]
134 pub async fn pipeline_bytes_fast(
135 &mut self,
136 wire_bytes: &[u8],
137 expected_queries: usize,
138 ) -> PgResult<usize> {
139 self.stream.write_all(wire_bytes).await?;
140 self.stream.flush().await?;
141
142 let mut queries_completed = 0;
143
144 loop {
145 let msg_type = self.recv_msg_type_fast().await?;
146 match msg_type {
147 b'C' | b'n' => queries_completed += 1,
148 b'Z' => {
149 if queries_completed == expected_queries {
150 return Ok(queries_completed);
151 }
152 }
153 _ => {}
154 }
155 }
156 }
157
158 #[inline]
160 pub async fn pipeline_simple_fast(
161 &mut self,
162 cmds: &[qail_core::ast::Qail],
163 ) -> PgResult<usize> {
164 let buf = AstEncoder::encode_batch_simple(cmds).map_err(|e| PgError::Encode(e.to_string()))?;
165 self.stream.write_all(&buf).await?;
166 self.stream.flush().await?;
167
168 let mut queries_completed = 0;
169
170 loop {
171 let msg_type = self.recv_msg_type_fast().await?;
172 match msg_type {
173 b'C' => queries_completed += 1,
174 b'Z' => {
175 if queries_completed == cmds.len() {
176 return Ok(queries_completed);
177 }
178 }
179 _ => {}
180 }
181 }
182 }
183
184 #[inline]
186 pub async fn pipeline_simple_bytes_fast(
187 &mut self,
188 wire_bytes: &[u8],
189 expected_queries: usize,
190 ) -> PgResult<usize> {
191 self.stream.write_all(wire_bytes).await?;
192 self.stream.flush().await?;
193
194 let mut queries_completed = 0;
195
196 loop {
197 let msg_type = self.recv_msg_type_fast().await?;
198 match msg_type {
199 b'C' => queries_completed += 1,
200 b'Z' => {
201 if queries_completed == expected_queries {
202 return Ok(queries_completed);
203 }
204 }
205 _ => {}
206 }
207 }
208 }
209
210 #[inline]
215 pub async fn pipeline_ast_cached(
216 &mut self,
217 cmds: &[qail_core::ast::Qail],
218 ) -> PgResult<usize> {
219 if cmds.is_empty() {
220 return Ok(0);
221 }
222
223 let mut buf = BytesMut::with_capacity(cmds.len() * 64);
224
225 for cmd in cmds {
226 let (sql, params) = AstEncoder::encode_cmd_sql(cmd).map_err(|e| PgError::Encode(e.to_string()))?;
227 let stmt_name = Self::sql_to_stmt_name(&sql);
228
229 if !self.prepared_statements.contains_key(&stmt_name) {
230 self.evict_prepared_if_full();
231 buf.extend(PgEncoder::encode_parse(&stmt_name, &sql, &[]));
232 self.prepared_statements.insert(stmt_name.clone(), sql);
233 }
234
235 buf.extend_from_slice(&PgEncoder::encode_bind("", &stmt_name, ¶ms)
236 .map_err(|e| PgError::Encode(e.to_string()))?);
237 buf.extend(PgEncoder::encode_execute("", 0));
238 }
239
240 buf.extend(PgEncoder::encode_sync());
241
242 self.stream.write_all(&buf).await?;
243 self.stream.flush().await?;
244
245 let mut queries_completed = 0;
246
247 loop {
248 let msg_type = self.recv_msg_type_fast().await?;
249 match msg_type {
250 b'C' | b'n' => queries_completed += 1,
251 b'Z' => {
252 if queries_completed == cmds.len() {
253 return Ok(queries_completed);
254 }
255 }
256 _ => {}
257 }
258 }
259 }
260
261 #[inline]
276 pub async fn pipeline_prepared_fast(
277 &mut self,
278 stmt: &super::PreparedStatement,
279 params_batch: &[Vec<Option<Vec<u8>>>],
280 ) -> PgResult<usize> {
281 if params_batch.is_empty() {
282 return Ok(0);
283 }
284
285 let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
287
288 let is_new = !self.prepared_statements.contains_key(&stmt.name);
289
290 if is_new {
291 return Err(PgError::Query(
292 "Statement not prepared. Call prepare() first.".to_string(),
293 ));
294 }
295
296 for params in params_batch {
298 PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
299 .map_err(|e| PgError::Encode(e.to_string()))?;
300 PgEncoder::encode_execute_to(&mut buf);
301 }
302
303 PgEncoder::encode_sync_to(&mut buf);
304
305 self.stream.write_all(&buf).await?;
306 self.stream.flush().await?;
307
308 let mut queries_completed = 0;
309
310 loop {
311 let msg_type = self.recv_msg_type_fast().await?;
312 match msg_type {
313 b'C' | b'n' => queries_completed += 1,
314 b'Z' => {
315 if queries_completed == params_batch.len() {
316 return Ok(queries_completed);
317 }
318 }
319 _ => {}
320 }
321 }
322 }
323
324 pub async fn prepare(&mut self, sql: &str) -> PgResult<super::PreparedStatement> {
327 use super::prepared::sql_bytes_to_stmt_name;
328
329 let stmt_name = sql_bytes_to_stmt_name(sql.as_bytes());
330
331 if !self.prepared_statements.contains_key(&stmt_name) {
332 self.evict_prepared_if_full();
333 let mut buf = BytesMut::with_capacity(sql.len() + 32);
334 buf.extend(PgEncoder::encode_parse(&stmt_name, sql, &[]));
335 buf.extend(PgEncoder::encode_sync());
336
337 self.stream.write_all(&buf).await?;
338 self.stream.flush().await?;
339
340 loop {
342 let msg_type = self.recv_msg_type_fast().await?;
343 match msg_type {
344 b'1' => {
345 self.prepared_statements
347 .insert(stmt_name.clone(), sql.to_string());
348 }
349 b'Z' => break, _ => {}
351 }
352 }
353 }
354
355 Ok(super::PreparedStatement {
356 name: stmt_name,
357 param_count: sql.matches('$').count(),
358 })
359 }
360
361 pub async fn pipeline_prepared_results(
363 &mut self,
364 stmt: &super::PreparedStatement,
365 params_batch: &[Vec<Option<Vec<u8>>>],
366 ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
367 if params_batch.is_empty() {
368 return Ok(Vec::new());
369 }
370
371 if !self.prepared_statements.contains_key(&stmt.name) {
372 return Err(PgError::Query(
373 "Statement not prepared. Call prepare() first.".to_string(),
374 ));
375 }
376
377 let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
378
379 for params in params_batch {
380 PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
381 .map_err(|e| PgError::Encode(e.to_string()))?;
382 PgEncoder::encode_execute_to(&mut buf);
383 }
384
385 PgEncoder::encode_sync_to(&mut buf);
386
387 self.stream.write_all(&buf).await?;
388 self.stream.flush().await?;
389
390 let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> =
392 Vec::with_capacity(params_batch.len());
393 let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
394
395 loop {
396 let (msg_type, data) = self.recv_with_data_fast().await?;
397
398 match msg_type {
399 b'2' => {} b'T' => {} b'D' => {
402 if let Some(row) = data {
404 current_rows.push(row);
405 }
406 }
407 b'C' => {
408 all_results.push(std::mem::take(&mut current_rows));
410 }
411 b'n' => {
412 all_results.push(Vec::new());
414 }
415 b'Z' => {
416 if all_results.len() == params_batch.len() {
418 return Ok(all_results);
419 }
420 }
421 _ => {}
422 }
423 }
424 }
425
426 pub async fn pipeline_prepared_zerocopy(
428 &mut self,
429 stmt: &super::PreparedStatement,
430 params_batch: &[Vec<Option<Vec<u8>>>],
431 ) -> PgResult<Vec<Vec<Vec<Option<bytes::Bytes>>>>> {
432 if params_batch.is_empty() {
433 return Ok(Vec::new());
434 }
435
436 if !self.prepared_statements.contains_key(&stmt.name) {
437 return Err(PgError::Query(
438 "Statement not prepared. Call prepare() first.".to_string(),
439 ));
440 }
441
442 let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
443
444 for params in params_batch {
445 PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
446 .map_err(|e| PgError::Encode(e.to_string()))?;
447 PgEncoder::encode_execute_to(&mut buf);
448 }
449
450 PgEncoder::encode_sync_to(&mut buf);
451
452 self.stream.write_all(&buf).await?;
453 self.stream.flush().await?;
454
455 let mut all_results: Vec<Vec<Vec<Option<bytes::Bytes>>>> =
457 Vec::with_capacity(params_batch.len());
458 let mut current_rows: Vec<Vec<Option<bytes::Bytes>>> = Vec::new();
459
460 loop {
461 let (msg_type, data) = self.recv_data_zerocopy().await?;
462
463 match msg_type {
464 b'2' => {} b'T' => {} b'D' => {
467 if let Some(row) = data {
469 current_rows.push(row);
470 }
471 }
472 b'C' => {
473 all_results.push(std::mem::take(&mut current_rows));
475 }
476 b'n' => {
477 all_results.push(Vec::new());
479 }
480 b'Z' => {
481 if all_results.len() == params_batch.len() {
483 return Ok(all_results);
484 }
485 }
486 _ => {}
487 }
488 }
489 }
490
491 pub async fn pipeline_prepared_ultra(
493 &mut self,
494 stmt: &super::PreparedStatement,
495 params_batch: &[Vec<Option<Vec<u8>>>],
496 ) -> PgResult<Vec<Vec<(bytes::Bytes, bytes::Bytes)>>> {
497 if params_batch.is_empty() {
498 return Ok(Vec::new());
499 }
500
501 if !self.prepared_statements.contains_key(&stmt.name) {
502 return Err(PgError::Query(
503 "Statement not prepared. Call prepare() first.".to_string(),
504 ));
505 }
506
507 let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
508
509 for params in params_batch {
510 PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
511 .map_err(|e| PgError::Encode(e.to_string()))?;
512 PgEncoder::encode_execute_to(&mut buf);
513 }
514
515 PgEncoder::encode_sync_to(&mut buf);
516
517 self.stream.write_all(&buf).await?;
518 self.stream.flush().await?;
519
520 let mut all_results: Vec<Vec<(bytes::Bytes, bytes::Bytes)>> =
522 Vec::with_capacity(params_batch.len());
523 let mut current_rows: Vec<(bytes::Bytes, bytes::Bytes)> = Vec::with_capacity(16);
524
525 loop {
526 let (msg_type, data) = self.recv_data_ultra().await?;
527
528 match msg_type {
529 b'2' | b'T' => {} b'D' => {
531 if let Some(row) = data {
532 current_rows.push(row);
533 }
534 }
535 b'C' => {
536 all_results.push(std::mem::take(&mut current_rows));
537 current_rows = Vec::with_capacity(16);
538 }
539 b'n' => {
540 all_results.push(Vec::new());
541 }
542 b'Z' => {
543 if all_results.len() == params_batch.len() {
544 return Ok(all_results);
545 }
546 }
547 _ => {}
548 }
549 }
550 }
551}