Skip to main content

qail_pg/driver/
copy.rs

1//! COPY protocol methods for PostgreSQL bulk operations.
2//!
3
4use super::{
5    PgConnection, PgError, PgResult, is_ignorable_session_message, parse_affected_rows,
6    unexpected_backend_message,
7};
8use crate::protocol::{AstEncoder, BackendMessage, PgEncoder};
9use bytes::BytesMut;
10use qail_core::ast::{Action, Qail};
11use std::future::Future;
12
/// Render `ident` as a double-quoted SQL identifier.
///
/// NUL characters are dropped and every embedded `"` is doubled, so the
/// value cannot terminate the quoted identifier early.
fn quote_ident(ident: &str) -> String {
    // Two extra bytes for the surrounding quotes; doubled quotes may grow it.
    let mut quoted = String::with_capacity(ident.len() + 2);
    quoted.push('"');
    for ch in ident.chars() {
        match ch {
            '\0' => {}
            '"' => quoted.push_str("\"\""),
            other => quoted.push(other),
        }
    }
    quoted.push('"');
    quoted
}
18
/// Split one COPY text-format line into its tab-separated fields.
///
/// A single trailing `\r` is removed first. The remaining bytes are decoded
/// lossily as UTF-8 and split on literal tab characters only; backslash
/// escape sequences inside a field are returned as-is, not decoded.
fn parse_copy_text_row(line: &[u8]) -> Vec<String> {
    let trimmed = line.strip_suffix(b"\r").unwrap_or(line);
    let decoded = String::from_utf8_lossy(trimmed);
    decoded.split('\t').map(String::from).collect()
}
28
29fn drain_copy_text_rows<F>(pending: &mut Vec<u8>, chunk: &[u8], on_row: &mut F) -> PgResult<()>
30where
31    F: FnMut(Vec<String>) -> PgResult<()>,
32{
33    pending.extend_from_slice(chunk);
34    while let Some(pos) = pending.iter().position(|&b| b == b'\n') {
35        let line = pending[..pos].to_vec();
36        pending.drain(..=pos);
37        on_row(parse_copy_text_row(&line))?;
38    }
39    Ok(())
40}
41
42fn flush_pending_copy_text_row<F>(pending: &mut Vec<u8>, on_row: &mut F) -> PgResult<()>
43where
44    F: FnMut(Vec<String>) -> PgResult<()>,
45{
46    if pending.is_empty() {
47        return Ok(());
48    }
49    let line = std::mem::take(pending);
50    on_row(parse_copy_text_row(&line))
51}
52
53impl PgConnection {
54    /// **Fast** bulk insert using COPY protocol with zero-allocation encoding.
55    /// Encodes all rows into a single buffer and writes with one syscall.
56    /// ~2x faster than `copy_in_internal` due to batched I/O.
57    pub(crate) async fn copy_in_fast(
58        &mut self,
59        table: &str,
60        columns: &[String],
61        rows: &[Vec<qail_core::ast::Value>],
62    ) -> PgResult<u64> {
63        use crate::protocol::encode_copy_batch;
64
65        let cols: Vec<String> = columns.iter().map(|c| quote_ident(c)).collect();
66        let sql = format!(
67            "COPY {} ({}) FROM STDIN",
68            quote_ident(table),
69            cols.join(", ")
70        );
71
72        // Send COPY command
73        let bytes = PgEncoder::try_encode_query_string(&sql)?;
74        self.write_all_with_timeout(&bytes, "stream write").await?;
75
76        // Wait for CopyInResponse
77        let mut startup_error: Option<PgError> = None;
78        loop {
79            let msg = self.recv().await?;
80            match msg {
81                BackendMessage::CopyInResponse { .. } => {
82                    if let Some(err) = startup_error {
83                        return Err(err);
84                    }
85                    break;
86                }
87                BackendMessage::ReadyForQuery(_) => {
88                    return Err(startup_error.unwrap_or_else(|| {
89                        PgError::Protocol(
90                            "COPY IN failed before CopyInResponse (unexpected ReadyForQuery)"
91                                .to_string(),
92                        )
93                    }));
94                }
95                BackendMessage::ErrorResponse(err) => {
96                    if startup_error.is_none() {
97                        startup_error = Some(PgError::QueryServer(err.into()));
98                    }
99                }
100                msg if is_ignorable_session_message(&msg) => {}
101                other => {
102                    return Err(unexpected_backend_message("copy-in startup", &other));
103                }
104            }
105        }
106
107        // Encode ALL rows into a single buffer (zero-allocation per value)
108        let batch_data = encode_copy_batch(rows);
109
110        // Single write for entire batch!
111        self.send_copy_data(&batch_data).await?;
112
113        // Send CopyDone
114        self.send_copy_done().await?;
115
116        // Wait for CommandComplete
117        let mut affected = 0u64;
118        let mut final_error: Option<PgError> = None;
119        let mut saw_command_complete = false;
120        loop {
121            let msg = self.recv().await?;
122            match msg {
123                BackendMessage::CommandComplete(tag) => {
124                    if saw_command_complete {
125                        return Err(PgError::Protocol(
126                            "COPY IN received duplicate CommandComplete".to_string(),
127                        ));
128                    }
129                    saw_command_complete = true;
130                    if final_error.is_none() {
131                        affected = parse_affected_rows(&tag);
132                    }
133                }
134                BackendMessage::ReadyForQuery(_) => {
135                    if let Some(err) = final_error {
136                        return Err(err);
137                    }
138                    if !saw_command_complete {
139                        return Err(PgError::Protocol(
140                            "COPY IN completion missing CommandComplete before ReadyForQuery"
141                                .to_string(),
142                        ));
143                    }
144                    return Ok(affected);
145                }
146                BackendMessage::ErrorResponse(err) => {
147                    if final_error.is_none() {
148                        final_error = Some(PgError::QueryServer(err.into()));
149                    }
150                }
151                msg if is_ignorable_session_message(&msg) => {}
152                other => {
153                    return Err(unexpected_backend_message("copy-in completion", &other));
154                }
155            }
156        }
157    }
158
159    /// **Fastest** bulk insert using COPY protocol with pre-encoded data.
160    /// Accepts raw COPY text format bytes, no encoding needed.
161    /// Use when caller has already encoded rows to COPY format.
162    /// # Format
163    /// Data should be tab-separated rows with newlines:
164    /// `1\thello\t3.14\n2\tworld\t2.71\n`
165    pub async fn copy_in_raw(
166        &mut self,
167        table: &str,
168        columns: &[String],
169        data: &[u8],
170    ) -> PgResult<u64> {
171        let cols: Vec<String> = columns.iter().map(|c| quote_ident(c)).collect();
172        let sql = format!(
173            "COPY {} ({}) FROM STDIN",
174            quote_ident(table),
175            cols.join(", ")
176        );
177
178        // Send COPY command
179        let bytes = PgEncoder::try_encode_query_string(&sql)?;
180        self.write_all_with_timeout(&bytes, "stream write").await?;
181
182        // Wait for CopyInResponse
183        let mut startup_error: Option<PgError> = None;
184        loop {
185            let msg = self.recv().await?;
186            match msg {
187                BackendMessage::CopyInResponse { .. } => {
188                    if let Some(err) = startup_error {
189                        return Err(err);
190                    }
191                    break;
192                }
193                BackendMessage::ReadyForQuery(_) => {
194                    return Err(startup_error.unwrap_or_else(|| {
195                        PgError::Protocol(
196                            "COPY IN failed before CopyInResponse (unexpected ReadyForQuery)"
197                                .to_string(),
198                        )
199                    }));
200                }
201                BackendMessage::ErrorResponse(err) => {
202                    if startup_error.is_none() {
203                        startup_error = Some(PgError::QueryServer(err.into()));
204                    }
205                }
206                msg if is_ignorable_session_message(&msg) => {}
207                other => {
208                    return Err(unexpected_backend_message("copy-in raw startup", &other));
209                }
210            }
211        }
212
213        // Single write - data is already encoded!
214        self.send_copy_data(data).await?;
215
216        // Send CopyDone
217        self.send_copy_done().await?;
218
219        // Wait for CommandComplete
220        let mut affected = 0u64;
221        let mut final_error: Option<PgError> = None;
222        let mut saw_command_complete = false;
223        loop {
224            let msg = self.recv().await?;
225            match msg {
226                BackendMessage::CommandComplete(tag) => {
227                    if saw_command_complete {
228                        return Err(PgError::Protocol(
229                            "COPY IN raw received duplicate CommandComplete".to_string(),
230                        ));
231                    }
232                    saw_command_complete = true;
233                    if final_error.is_none() {
234                        affected = parse_affected_rows(&tag);
235                    }
236                }
237                BackendMessage::ReadyForQuery(_) => {
238                    if let Some(err) = final_error {
239                        return Err(err);
240                    }
241                    if !saw_command_complete {
242                        return Err(PgError::Protocol(
243                            "COPY IN raw completion missing CommandComplete before ReadyForQuery"
244                                .to_string(),
245                        ));
246                    }
247                    return Ok(affected);
248                }
249                BackendMessage::ErrorResponse(err) => {
250                    if final_error.is_none() {
251                        final_error = Some(PgError::QueryServer(err.into()));
252                    }
253                }
254                msg if is_ignorable_session_message(&msg) => {}
255                other => {
256                    return Err(unexpected_backend_message("copy-in raw completion", &other));
257                }
258            }
259        }
260    }
261
262    /// Send CopyData message (raw bytes).
263    pub(crate) async fn send_copy_data(&mut self, data: &[u8]) -> PgResult<()> {
264        let total_len = data
265            .len()
266            .checked_add(4)
267            .ok_or_else(|| PgError::Protocol("CopyData frame length overflow".to_string()))?;
268        let len = i32::try_from(total_len)
269            .map_err(|_| PgError::Protocol("CopyData frame exceeds i32::MAX".to_string()))?;
270
271        // CopyData: 'd' + length + data
272        let mut buf = BytesMut::with_capacity(1 + 4 + data.len());
273        buf.extend_from_slice(b"d");
274        buf.extend_from_slice(&len.to_be_bytes());
275        buf.extend_from_slice(data);
276        self.write_all_with_timeout(&buf, "stream write").await?;
277        Ok(())
278    }
279
280    async fn send_copy_done(&mut self) -> PgResult<()> {
281        // CopyDone: 'c' + length (4)
282        self.write_all_with_timeout(&[b'c', 0, 0, 0, 4], "stream write")
283            .await?;
284        Ok(())
285    }
286
287    async fn start_copy_out(&mut self, sql: &str, context: &str) -> PgResult<()> {
288        let bytes = PgEncoder::try_encode_query_string(sql)?;
289        self.write_all_with_timeout(&bytes, "stream write").await?;
290
291        let mut startup_error: Option<PgError> = None;
292        loop {
293            let msg = self.recv().await?;
294            match msg {
295                BackendMessage::CopyOutResponse { .. } => {
296                    if let Some(err) = startup_error {
297                        return Err(err);
298                    }
299                    return Ok(());
300                }
301                BackendMessage::ReadyForQuery(_) => {
302                    return Err(startup_error.unwrap_or_else(|| {
303                        PgError::Protocol(format!(
304                            "{} failed before CopyOutResponse (unexpected ReadyForQuery)",
305                            context
306                        ))
307                    }));
308                }
309                BackendMessage::ErrorResponse(err) => {
310                    if startup_error.is_none() {
311                        startup_error = Some(PgError::QueryServer(err.into()));
312                    }
313                }
314                msg if is_ignorable_session_message(&msg) => {}
315                other => return Err(unexpected_backend_message(context, &other)),
316            }
317        }
318    }
319
    /// Read backend messages until ReadyForQuery, passing each CopyData
    /// payload to `on_chunk`.
    ///
    /// `context` is used in the protocol error messages produced here.
    ///
    /// Once a server ErrorResponse or a callback error has been recorded,
    /// later CopyData payloads are no longer passed to `on_chunk`, but the
    /// loop keeps reading until ReadyForQuery.
    ///
    /// # Errors
    /// - Returned immediately: errors from `recv`, CopyData after CopyDone,
    ///   a duplicate CopyDone or CommandComplete, and any message that is
    ///   neither handled here nor accepted by `is_ignorable_session_message`.
    /// - Returned at ReadyForQuery, in this order of precedence: the first
    ///   server ErrorResponse, the first callback error, a missing CopyDone,
    ///   a missing CommandComplete.
    async fn stream_copy_out_chunks<F, Fut>(
        &mut self,
        context: &str,
        mut on_chunk: F,
    ) -> PgResult<()>
    where
        F: FnMut(Vec<u8>) -> Fut,
        Fut: Future<Output = PgResult<()>>,
    {
        // First ErrorResponse sent by the server, if any.
        let mut stream_error: Option<PgError> = None;
        // First error returned by `on_chunk`, if any.
        let mut callback_error: Option<PgError> = None;
        let mut saw_copy_done = false;
        let mut saw_command_complete = false;

        loop {
            let msg = self.recv().await?;
            match msg {
                BackendMessage::CopyData(chunk) => {
                    if saw_copy_done {
                        return Err(PgError::Protocol(format!(
                            "{} received CopyData after CopyDone",
                            context
                        )));
                    }
                    // The callback runs only while no error of either kind is
                    // recorded; afterwards chunks are read and discarded.
                    if stream_error.is_none()
                        && callback_error.is_none()
                        && let Err(e) = on_chunk(chunk).await
                    {
                        callback_error = Some(e);
                    }
                }
                BackendMessage::CopyDone => {
                    if saw_copy_done {
                        return Err(PgError::Protocol(format!(
                            "{} received duplicate CopyDone",
                            context
                        )));
                    }
                    saw_copy_done = true;
                }
                BackendMessage::CommandComplete(_) => {
                    if saw_command_complete {
                        return Err(PgError::Protocol(format!(
                            "{} received duplicate CommandComplete",
                            context
                        )));
                    }
                    saw_command_complete = true;
                }
                BackendMessage::ReadyForQuery(_) => {
                    // Deferred errors are reported here: server error first,
                    // then callback error, then sequencing checks.
                    if let Some(err) = stream_error {
                        return Err(err);
                    }
                    if let Some(err) = callback_error {
                        return Err(err);
                    }
                    if !saw_copy_done {
                        return Err(PgError::Protocol(format!(
                            "{} missing CopyDone before ReadyForQuery",
                            context
                        )));
                    }
                    if !saw_command_complete {
                        return Err(PgError::Protocol(format!(
                            "{} missing CommandComplete before ReadyForQuery",
                            context
                        )));
                    }
                    return Ok(());
                }
                BackendMessage::ErrorResponse(err) => {
                    // Keep only the first server error; later ones are dropped.
                    if stream_error.is_none() {
                        stream_error = Some(PgError::QueryServer(err.into()));
                    }
                }
                msg if is_ignorable_session_message(&msg) => {}
                other => return Err(unexpected_backend_message(context, &other)),
            }
        }
    }
400
401    /// Export data using COPY TO STDOUT (AST-native).
402    /// Takes a `Qail::Export` and returns rows as `Vec<Vec<String>>`.
403    /// # Example
404    /// ```ignore
405    /// let cmd = Qail::export("users")
406    ///     .columns(["id", "name"])
407    ///     .filter("active", true);
408    /// let rows = conn.copy_export(&cmd).await?;
409    /// ```
410    pub async fn copy_export(&mut self, cmd: &Qail) -> PgResult<Vec<Vec<String>>> {
411        let mut rows = Vec::new();
412        self.copy_export_stream_rows(cmd, |row| {
413            rows.push(row);
414            Ok(())
415        })
416        .await?;
417        Ok(rows)
418    }
419
420    /// Stream COPY TO STDOUT chunks using an AST-native `Qail::Export` command.
421    ///
422    /// Chunks are forwarded as they arrive from PostgreSQL, so memory usage
423    /// stays bounded by network frame size and callback processing.
424    pub async fn copy_export_stream_raw<F, Fut>(&mut self, cmd: &Qail, on_chunk: F) -> PgResult<()>
425    where
426        F: FnMut(Vec<u8>) -> Fut,
427        Fut: Future<Output = PgResult<()>>,
428    {
429        if cmd.action != Action::Export {
430            return Err(PgError::Query(
431                "copy_export requires Qail::Export action".to_string(),
432            ));
433        }
434
435        // Encode command to SQL using AST encoder
436        let (sql, _params) =
437            AstEncoder::encode_cmd_sql(cmd).map_err(|e| PgError::Encode(e.to_string()))?;
438
439        self.copy_out_raw_stream(&sql, on_chunk).await
440    }
441
442    /// Stream COPY TO STDOUT rows using an AST-native `Qail::Export` command.
443    ///
444    /// Parses PostgreSQL COPY text lines into `Vec<String>` rows and invokes
445    /// `on_row` for each row without buffering the full result.
446    pub async fn copy_export_stream_rows<F>(&mut self, cmd: &Qail, mut on_row: F) -> PgResult<()>
447    where
448        F: FnMut(Vec<String>) -> PgResult<()>,
449    {
450        let mut pending = Vec::new();
451        self.copy_export_stream_raw(cmd, |chunk| {
452            let res = drain_copy_text_rows(&mut pending, &chunk, &mut on_row);
453            std::future::ready(res)
454        })
455        .await?;
456        flush_pending_copy_text_row(&mut pending, &mut on_row)
457    }
458
459    /// Export data using raw COPY TO STDOUT, returning raw bytes.
460    /// Format: tab-separated values, newline-terminated rows.
461    /// Suitable for direct re-import via copy_in_raw.
462    ///
463    /// # Safety
464    /// `pub(crate)` — not exposed externally because callers pass raw SQL.
465    /// External code should use `copy_export()` with the AST encoder instead.
466    pub(crate) async fn copy_out_raw(&mut self, sql: &str) -> PgResult<Vec<u8>> {
467        let mut data = Vec::new();
468        self.copy_out_raw_stream(sql, |chunk| {
469            data.extend_from_slice(&chunk);
470            std::future::ready(Ok(()))
471        })
472        .await?;
473        Ok(data)
474    }
475
476    /// Stream raw COPY TO STDOUT bytes with bounded memory usage.
477    ///
478    /// # Safety
479    /// `pub(crate)` — callers pass raw SQL.
480    pub(crate) async fn copy_out_raw_stream<F, Fut>(
481        &mut self,
482        sql: &str,
483        on_chunk: F,
484    ) -> PgResult<()>
485    where
486        F: FnMut(Vec<u8>) -> Fut,
487        Fut: Future<Output = PgResult<()>>,
488    {
489        self.start_copy_out(sql, "copy-out raw startup").await?;
490        self.stream_copy_out_chunks("copy-out raw stream", on_chunk)
491            .await
492    }
493}
494
#[cfg(test)]
mod tests {
    use super::{drain_copy_text_rows, flush_pending_copy_text_row, parse_copy_text_row};
    use crate::driver::{PgError, PgResult};

    /// Build an owned row from string literals to keep expectations short.
    fn row(fields: &[&str]) -> Vec<String> {
        fields.iter().map(|f| f.to_string()).collect()
    }

    #[test]
    fn parse_copy_text_row_splits_tabs() {
        assert_eq!(parse_copy_text_row(b"a\tb\tc"), row(&["a", "b", "c"]));
    }

    #[test]
    fn parse_copy_text_row_trims_cr() {
        assert_eq!(parse_copy_text_row(b"a\tb\r"), row(&["a", "b"]));
    }

    #[test]
    fn drain_copy_text_rows_handles_chunk_boundaries() {
        let mut pending: Vec<u8> = Vec::new();
        let mut seen: Vec<Vec<String>> = Vec::new();

        // First chunk ends mid-line: one full row, "c" left over.
        let mut first_sink = |r: Vec<String>| -> PgResult<()> {
            seen.push(r);
            Ok(())
        };
        drain_copy_text_rows(&mut pending, b"a\tb\nc", &mut first_sink).unwrap();
        assert_eq!(seen, vec![row(&["a", "b"])]);
        assert_eq!(pending, b"c");

        // Second chunk completes the leftover line.
        let mut second_sink = |r: Vec<String>| -> PgResult<()> {
            seen.push(r);
            Ok(())
        };
        drain_copy_text_rows(&mut pending, b"\td\n", &mut second_sink).unwrap();
        assert_eq!(seen, vec![row(&["a", "b"]), row(&["c", "d"])]);
        assert!(pending.is_empty());
    }

    #[test]
    fn flush_pending_copy_text_row_emits_final_partial_line() {
        let mut pending = Vec::from(&b"x\ty"[..]);
        let mut seen: Vec<Vec<String>> = Vec::new();
        let mut sink = |r: Vec<String>| -> PgResult<()> {
            seen.push(r);
            Ok(())
        };

        flush_pending_copy_text_row(&mut pending, &mut sink).unwrap();
        assert_eq!(seen, vec![row(&["x", "y"])]);
        assert!(pending.is_empty());
    }

    #[test]
    fn callback_error_bubbles_from_row_drainer() {
        let mut pending: Vec<u8> = Vec::new();
        let mut always_fail =
            |_r: Vec<String>| -> PgResult<()> { Err(PgError::Query("fail".to_string())) };

        let err = drain_copy_text_rows(&mut pending, b"a\tb\n", &mut always_fail).unwrap_err();
        match err {
            PgError::Query(msg) => assert_eq!(msg, "fail"),
            _ => panic!("expected PgError::Query"),
        }
    }
}
563}