1use super::{
5 PgConnection, PgError, PgResult, is_ignorable_session_message, parse_affected_rows,
6 unexpected_backend_message,
7};
8use crate::protocol::{AstEncoder, BackendMessage, PgEncoder};
9use bytes::BytesMut;
10use qail_core::ast::{Action, Qail};
11use std::future::Future;
12
/// Quotes `ident` as a single PostgreSQL identifier.
///
/// The result is wrapped in double quotes and every embedded `"` is doubled.
/// NUL bytes are dropped because simple-query strings are NUL-terminated on
/// the wire. A dotted name such as `schema.table` is quoted as ONE identifier;
/// it is not split on `.`.
fn quote_ident(ident: &str) -> String {
    let mut quoted = String::with_capacity(ident.len() + 2);
    quoted.push('"');
    for ch in ident.chars() {
        match ch {
            '\0' => {}
            '"' => quoted.push_str("\"\""),
            other => quoted.push(other),
        }
    }
    quoted.push('"');
    quoted
}
18
/// Splits one line of COPY text output into its tab-separated fields.
///
/// A single trailing `\r` (left over from a CRLF terminator) is removed
/// before splitting. Invalid UTF-8 is replaced with U+FFFD. Fields are
/// returned verbatim: backslash escape sequences are NOT decoded. With
/// PostgreSQL's default text-format options, a NULL therefore arrives as the
/// literal two-character string `\N`.
fn parse_copy_text_row(line: &[u8]) -> Vec<String> {
    let body = line.strip_suffix(b"\r").unwrap_or(line);
    String::from_utf8_lossy(body)
        .split('\t')
        .map(str::to_owned)
        .collect()
}
28
29fn drain_copy_text_rows<F>(pending: &mut Vec<u8>, chunk: &[u8], on_row: &mut F) -> PgResult<()>
30where
31 F: FnMut(Vec<String>) -> PgResult<()>,
32{
33 pending.extend_from_slice(chunk);
34 while let Some(pos) = pending.iter().position(|&b| b == b'\n') {
35 let line = pending[..pos].to_vec();
36 pending.drain(..=pos);
37 on_row(parse_copy_text_row(&line))?;
38 }
39 Ok(())
40}
41
42fn flush_pending_copy_text_row<F>(pending: &mut Vec<u8>, on_row: &mut F) -> PgResult<()>
43where
44 F: FnMut(Vec<String>) -> PgResult<()>,
45{
46 if pending.is_empty() {
47 return Ok(());
48 }
49 let line = std::mem::take(pending);
50 on_row(parse_copy_text_row(&line))
51}
52
53impl PgConnection {
54 pub(crate) async fn copy_in_fast(
58 &mut self,
59 table: &str,
60 columns: &[String],
61 rows: &[Vec<qail_core::ast::Value>],
62 ) -> PgResult<u64> {
63 use crate::protocol::encode_copy_batch;
64
65 let cols: Vec<String> = columns.iter().map(|c| quote_ident(c)).collect();
66 let sql = format!(
67 "COPY {} ({}) FROM STDIN",
68 quote_ident(table),
69 cols.join(", ")
70 );
71
72 let bytes = PgEncoder::try_encode_query_string(&sql)?;
74 self.write_all_with_timeout(&bytes, "stream write").await?;
75
76 let mut startup_error: Option<PgError> = None;
78 loop {
79 let msg = self.recv().await?;
80 match msg {
81 BackendMessage::CopyInResponse { .. } => {
82 if let Some(err) = startup_error {
83 return Err(err);
84 }
85 break;
86 }
87 BackendMessage::ReadyForQuery(_) => {
88 return Err(startup_error.unwrap_or_else(|| {
89 PgError::Protocol(
90 "COPY IN failed before CopyInResponse (unexpected ReadyForQuery)"
91 .to_string(),
92 )
93 }));
94 }
95 BackendMessage::ErrorResponse(err) => {
96 if startup_error.is_none() {
97 startup_error = Some(PgError::QueryServer(err.into()));
98 }
99 }
100 msg if is_ignorable_session_message(&msg) => {}
101 other => {
102 return Err(unexpected_backend_message("copy-in startup", &other));
103 }
104 }
105 }
106
107 let batch_data = encode_copy_batch(rows);
109
110 self.send_copy_data(&batch_data).await?;
112
113 self.send_copy_done().await?;
115
116 let mut affected = 0u64;
118 let mut final_error: Option<PgError> = None;
119 let mut saw_command_complete = false;
120 loop {
121 let msg = self.recv().await?;
122 match msg {
123 BackendMessage::CommandComplete(tag) => {
124 if saw_command_complete {
125 return Err(PgError::Protocol(
126 "COPY IN received duplicate CommandComplete".to_string(),
127 ));
128 }
129 saw_command_complete = true;
130 if final_error.is_none() {
131 affected = parse_affected_rows(&tag);
132 }
133 }
134 BackendMessage::ReadyForQuery(_) => {
135 if let Some(err) = final_error {
136 return Err(err);
137 }
138 if !saw_command_complete {
139 return Err(PgError::Protocol(
140 "COPY IN completion missing CommandComplete before ReadyForQuery"
141 .to_string(),
142 ));
143 }
144 return Ok(affected);
145 }
146 BackendMessage::ErrorResponse(err) => {
147 if final_error.is_none() {
148 final_error = Some(PgError::QueryServer(err.into()));
149 }
150 }
151 msg if is_ignorable_session_message(&msg) => {}
152 other => {
153 return Err(unexpected_backend_message("copy-in completion", &other));
154 }
155 }
156 }
157 }
158
159 pub async fn copy_in_raw(
166 &mut self,
167 table: &str,
168 columns: &[String],
169 data: &[u8],
170 ) -> PgResult<u64> {
171 let cols: Vec<String> = columns.iter().map(|c| quote_ident(c)).collect();
172 let sql = format!(
173 "COPY {} ({}) FROM STDIN",
174 quote_ident(table),
175 cols.join(", ")
176 );
177
178 let bytes = PgEncoder::try_encode_query_string(&sql)?;
180 self.write_all_with_timeout(&bytes, "stream write").await?;
181
182 let mut startup_error: Option<PgError> = None;
184 loop {
185 let msg = self.recv().await?;
186 match msg {
187 BackendMessage::CopyInResponse { .. } => {
188 if let Some(err) = startup_error {
189 return Err(err);
190 }
191 break;
192 }
193 BackendMessage::ReadyForQuery(_) => {
194 return Err(startup_error.unwrap_or_else(|| {
195 PgError::Protocol(
196 "COPY IN failed before CopyInResponse (unexpected ReadyForQuery)"
197 .to_string(),
198 )
199 }));
200 }
201 BackendMessage::ErrorResponse(err) => {
202 if startup_error.is_none() {
203 startup_error = Some(PgError::QueryServer(err.into()));
204 }
205 }
206 msg if is_ignorable_session_message(&msg) => {}
207 other => {
208 return Err(unexpected_backend_message("copy-in raw startup", &other));
209 }
210 }
211 }
212
213 self.send_copy_data(data).await?;
215
216 self.send_copy_done().await?;
218
219 let mut affected = 0u64;
221 let mut final_error: Option<PgError> = None;
222 let mut saw_command_complete = false;
223 loop {
224 let msg = self.recv().await?;
225 match msg {
226 BackendMessage::CommandComplete(tag) => {
227 if saw_command_complete {
228 return Err(PgError::Protocol(
229 "COPY IN raw received duplicate CommandComplete".to_string(),
230 ));
231 }
232 saw_command_complete = true;
233 if final_error.is_none() {
234 affected = parse_affected_rows(&tag);
235 }
236 }
237 BackendMessage::ReadyForQuery(_) => {
238 if let Some(err) = final_error {
239 return Err(err);
240 }
241 if !saw_command_complete {
242 return Err(PgError::Protocol(
243 "COPY IN raw completion missing CommandComplete before ReadyForQuery"
244 .to_string(),
245 ));
246 }
247 return Ok(affected);
248 }
249 BackendMessage::ErrorResponse(err) => {
250 if final_error.is_none() {
251 final_error = Some(PgError::QueryServer(err.into()));
252 }
253 }
254 msg if is_ignorable_session_message(&msg) => {}
255 other => {
256 return Err(unexpected_backend_message("copy-in raw completion", &other));
257 }
258 }
259 }
260 }
261
262 pub(crate) async fn send_copy_data(&mut self, data: &[u8]) -> PgResult<()> {
264 let total_len = data
265 .len()
266 .checked_add(4)
267 .ok_or_else(|| PgError::Protocol("CopyData frame length overflow".to_string()))?;
268 let len = i32::try_from(total_len)
269 .map_err(|_| PgError::Protocol("CopyData frame exceeds i32::MAX".to_string()))?;
270
271 let mut buf = BytesMut::with_capacity(1 + 4 + data.len());
273 buf.extend_from_slice(b"d");
274 buf.extend_from_slice(&len.to_be_bytes());
275 buf.extend_from_slice(data);
276 self.write_all_with_timeout(&buf, "stream write").await?;
277 Ok(())
278 }
279
280 async fn send_copy_done(&mut self) -> PgResult<()> {
281 self.write_all_with_timeout(&[b'c', 0, 0, 0, 4], "stream write")
283 .await?;
284 Ok(())
285 }
286
287 async fn start_copy_out(&mut self, sql: &str, context: &str) -> PgResult<()> {
288 let bytes = PgEncoder::try_encode_query_string(sql)?;
289 self.write_all_with_timeout(&bytes, "stream write").await?;
290
291 let mut startup_error: Option<PgError> = None;
292 loop {
293 let msg = self.recv().await?;
294 match msg {
295 BackendMessage::CopyOutResponse { .. } => {
296 if let Some(err) = startup_error {
297 return Err(err);
298 }
299 return Ok(());
300 }
301 BackendMessage::ReadyForQuery(_) => {
302 return Err(startup_error.unwrap_or_else(|| {
303 PgError::Protocol(format!(
304 "{} failed before CopyOutResponse (unexpected ReadyForQuery)",
305 context
306 ))
307 }));
308 }
309 BackendMessage::ErrorResponse(err) => {
310 if startup_error.is_none() {
311 startup_error = Some(PgError::QueryServer(err.into()));
312 }
313 }
314 msg if is_ignorable_session_message(&msg) => {}
315 other => return Err(unexpected_backend_message(context, &other)),
316 }
317 }
318 }
319
320 async fn stream_copy_out_chunks<F, Fut>(
321 &mut self,
322 context: &str,
323 mut on_chunk: F,
324 ) -> PgResult<()>
325 where
326 F: FnMut(Vec<u8>) -> Fut,
327 Fut: Future<Output = PgResult<()>>,
328 {
329 let mut stream_error: Option<PgError> = None;
330 let mut callback_error: Option<PgError> = None;
331 let mut saw_copy_done = false;
332 let mut saw_command_complete = false;
333
334 loop {
335 let msg = self.recv().await?;
336 match msg {
337 BackendMessage::CopyData(chunk) => {
338 if saw_copy_done {
339 return Err(PgError::Protocol(format!(
340 "{} received CopyData after CopyDone",
341 context
342 )));
343 }
344 if stream_error.is_none()
345 && callback_error.is_none()
346 && let Err(e) = on_chunk(chunk).await
347 {
348 callback_error = Some(e);
349 }
350 }
351 BackendMessage::CopyDone => {
352 if saw_copy_done {
353 return Err(PgError::Protocol(format!(
354 "{} received duplicate CopyDone",
355 context
356 )));
357 }
358 saw_copy_done = true;
359 }
360 BackendMessage::CommandComplete(_) => {
361 if saw_command_complete {
362 return Err(PgError::Protocol(format!(
363 "{} received duplicate CommandComplete",
364 context
365 )));
366 }
367 saw_command_complete = true;
368 }
369 BackendMessage::ReadyForQuery(_) => {
370 if let Some(err) = stream_error {
371 return Err(err);
372 }
373 if let Some(err) = callback_error {
374 return Err(err);
375 }
376 if !saw_copy_done {
377 return Err(PgError::Protocol(format!(
378 "{} missing CopyDone before ReadyForQuery",
379 context
380 )));
381 }
382 if !saw_command_complete {
383 return Err(PgError::Protocol(format!(
384 "{} missing CommandComplete before ReadyForQuery",
385 context
386 )));
387 }
388 return Ok(());
389 }
390 BackendMessage::ErrorResponse(err) => {
391 if stream_error.is_none() {
392 stream_error = Some(PgError::QueryServer(err.into()));
393 }
394 }
395 msg if is_ignorable_session_message(&msg) => {}
396 other => return Err(unexpected_backend_message(context, &other)),
397 }
398 }
399 }
400
401 pub async fn copy_export(&mut self, cmd: &Qail) -> PgResult<Vec<Vec<String>>> {
411 let mut rows = Vec::new();
412 self.copy_export_stream_rows(cmd, |row| {
413 rows.push(row);
414 Ok(())
415 })
416 .await?;
417 Ok(rows)
418 }
419
420 pub async fn copy_export_stream_raw<F, Fut>(&mut self, cmd: &Qail, on_chunk: F) -> PgResult<()>
425 where
426 F: FnMut(Vec<u8>) -> Fut,
427 Fut: Future<Output = PgResult<()>>,
428 {
429 if cmd.action != Action::Export {
430 return Err(PgError::Query(
431 "copy_export requires Qail::Export action".to_string(),
432 ));
433 }
434
435 let (sql, _params) =
437 AstEncoder::encode_cmd_sql(cmd).map_err(|e| PgError::Encode(e.to_string()))?;
438
439 self.copy_out_raw_stream(&sql, on_chunk).await
440 }
441
442 pub async fn copy_export_stream_rows<F>(&mut self, cmd: &Qail, mut on_row: F) -> PgResult<()>
447 where
448 F: FnMut(Vec<String>) -> PgResult<()>,
449 {
450 let mut pending = Vec::new();
451 self.copy_export_stream_raw(cmd, |chunk| {
452 let res = drain_copy_text_rows(&mut pending, &chunk, &mut on_row);
453 std::future::ready(res)
454 })
455 .await?;
456 flush_pending_copy_text_row(&mut pending, &mut on_row)
457 }
458
459 pub(crate) async fn copy_out_raw(&mut self, sql: &str) -> PgResult<Vec<u8>> {
467 let mut data = Vec::new();
468 self.copy_out_raw_stream(sql, |chunk| {
469 data.extend_from_slice(&chunk);
470 std::future::ready(Ok(()))
471 })
472 .await?;
473 Ok(data)
474 }
475
476 pub(crate) async fn copy_out_raw_stream<F, Fut>(
481 &mut self,
482 sql: &str,
483 on_chunk: F,
484 ) -> PgResult<()>
485 where
486 F: FnMut(Vec<u8>) -> Fut,
487 Fut: Future<Output = PgResult<()>>,
488 {
489 self.start_copy_out(sql, "copy-out raw startup").await?;
490 self.stream_copy_out_chunks("copy-out raw stream", on_chunk)
491 .await
492 }
493}
494
#[cfg(test)]
mod tests {
    use super::{
        drain_copy_text_rows, flush_pending_copy_text_row, parse_copy_text_row, quote_ident,
    };
    use crate::driver::{PgError, PgResult};

    #[test]
    fn quote_ident_doubles_quotes_and_strips_nul() {
        assert_eq!(quote_ident("users"), "\"users\"");
        assert_eq!(quote_ident("we\"ird"), "\"we\"\"ird\"");
        assert_eq!(quote_ident("a\0b"), "\"ab\"");
    }

    #[test]
    fn parse_copy_text_row_splits_tabs() {
        let row = parse_copy_text_row(b"a\tb\tc");
        assert_eq!(row, vec!["a", "b", "c"]);
    }

    #[test]
    fn parse_copy_text_row_trims_cr() {
        let row = parse_copy_text_row(b"a\tb\r");
        assert_eq!(row, vec!["a", "b"]);
    }

    #[test]
    fn parse_copy_text_row_empty_line_is_one_empty_field() {
        assert_eq!(parse_copy_text_row(b""), vec![""]);
    }

    #[test]
    fn drain_copy_text_rows_handles_chunk_boundaries() {
        let mut pending = Vec::new();
        let mut rows: Vec<Vec<String>> = Vec::new();

        drain_copy_text_rows(&mut pending, b"a\tb\nc", &mut |row: Vec<String>| {
            rows.push(row);
            Ok(())
        })
        .unwrap();
        assert_eq!(rows, vec![vec!["a".to_string(), "b".to_string()]]);
        assert_eq!(pending, b"c");

        drain_copy_text_rows(&mut pending, b"\td\n", &mut |row: Vec<String>| {
            rows.push(row);
            Ok(())
        })
        .unwrap();
        assert_eq!(
            rows,
            vec![
                vec!["a".to_string(), "b".to_string()],
                vec!["c".to_string(), "d".to_string()]
            ]
        );
        assert!(pending.is_empty());
    }

    #[test]
    fn drain_copy_text_rows_handles_many_rows_in_one_chunk() {
        let mut pending = Vec::new();
        let mut rows: Vec<Vec<String>> = Vec::new();

        drain_copy_text_rows(
            &mut pending,
            b"1\ta\n2\tb\n3\tc\n",
            &mut |row: Vec<String>| -> PgResult<()> {
                rows.push(row);
                Ok(())
            },
        )
        .unwrap();
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], vec!["1", "a"]);
        assert_eq!(rows[2], vec!["3", "c"]);
        assert!(pending.is_empty());
    }

    #[test]
    fn flush_pending_copy_text_row_emits_final_partial_line() {
        let mut pending = b"x\ty".to_vec();
        let mut rows = Vec::new();
        let mut on_row = |row: Vec<String>| -> PgResult<()> {
            rows.push(row);
            Ok(())
        };

        flush_pending_copy_text_row(&mut pending, &mut on_row).unwrap();
        assert_eq!(rows, vec![vec!["x".to_string(), "y".to_string()]]);
        assert!(pending.is_empty());
    }

    #[test]
    fn flush_pending_copy_text_row_is_noop_when_empty() {
        let mut pending = Vec::new();
        let mut called = false;

        flush_pending_copy_text_row(&mut pending, &mut |_row: Vec<String>| -> PgResult<()> {
            called = true;
            Ok(())
        })
        .unwrap();
        assert!(!called);
    }

    #[test]
    fn callback_error_bubbles_from_row_drainer() {
        let mut pending = Vec::new();
        let mut on_row =
            |_row: Vec<String>| -> PgResult<()> { Err(PgError::Query("fail".to_string())) };

        let err = drain_copy_text_rows(&mut pending, b"a\tb\n", &mut on_row).unwrap_err();
        assert!(matches!(err, PgError::Query(msg) if msg == "fail"));
    }

    #[test]
    fn drain_copy_text_rows_keeps_unread_bytes_after_callback_error() {
        let mut pending = Vec::new();
        let mut calls = 0;

        let err = drain_copy_text_rows(
            &mut pending,
            b"a\nb\nc\n",
            &mut |_row: Vec<String>| -> PgResult<()> {
                calls += 1;
                if calls == 2 {
                    Err(PgError::Query("stop".to_string()))
                } else {
                    Ok(())
                }
            },
        )
        .unwrap_err();
        assert!(matches!(err, PgError::Query(msg) if msg == "stop"));
        // The failing line is consumed; everything after it stays buffered.
        assert_eq!(calls, 2);
        assert_eq!(pending, b"c\n");
    }
}