Skip to main content

qail_pg/driver/
copy.rs

1//! COPY protocol methods for PostgreSQL bulk operations.
2//!
3
4use super::{PgConnection, PgError, PgResult, parse_affected_rows};
5use crate::protocol::{AstEncoder, BackendMessage, PgEncoder};
6use bytes::BytesMut;
7use qail_core::ast::{Action, Qail};
8use tokio::io::AsyncWriteExt;
9
/// Quote a SQL identifier so it can be spliced into a statement verbatim.
///
/// The result is wrapped in double quotes. Every embedded `"` is doubled
/// (`a"b` becomes `"a""b"`), and NUL characters are dropped entirely.
fn quote_ident(ident: &str) -> String {
    let mut quoted = String::with_capacity(ident.len() + 2);
    quoted.push('"');
    for ch in ident.chars().filter(|&c| c != '\0') {
        if ch == '"' {
            // A literal double quote inside a quoted identifier is written twice.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
15
16impl PgConnection {
17    /// **Fast** bulk insert using COPY protocol with zero-allocation encoding.
18    /// Encodes all rows into a single buffer and writes with one syscall.
19    /// ~2x faster than `copy_in_internal` due to batched I/O.
20    pub(crate) async fn copy_in_fast(
21        &mut self,
22        table: &str,
23        columns: &[String],
24        rows: &[Vec<qail_core::ast::Value>],
25    ) -> PgResult<u64> {
26        use crate::protocol::encode_copy_batch;
27
28        let cols: Vec<String> = columns.iter().map(|c| quote_ident(c)).collect();
29        let sql = format!(
30            "COPY {} ({}) FROM STDIN",
31            quote_ident(table),
32            cols.join(", ")
33        );
34
35        // Send COPY command
36        let bytes = PgEncoder::encode_query_string(&sql);
37        self.stream.write_all(&bytes).await?;
38
39        // Wait for CopyInResponse
40        loop {
41            let msg = self.recv().await?;
42            match msg {
43                BackendMessage::CopyInResponse { .. } => break,
44                BackendMessage::ErrorResponse(err) => {
45                    return Err(PgError::QueryServer(err.into()));
46                }
47                _ => {}
48            }
49        }
50
51        // Encode ALL rows into a single buffer (zero-allocation per value)
52        let batch_data = encode_copy_batch(rows);
53
54        // Single write for entire batch!
55        self.send_copy_data(&batch_data).await?;
56
57        // Send CopyDone
58        self.send_copy_done().await?;
59
60        // Wait for CommandComplete
61        let mut affected = 0u64;
62        loop {
63            let msg = self.recv().await?;
64            match msg {
65                BackendMessage::CommandComplete(tag) => {
66                    affected = parse_affected_rows(&tag);
67                }
68                BackendMessage::ReadyForQuery(_) => {
69                    return Ok(affected);
70                }
71                BackendMessage::ErrorResponse(err) => {
72                    return Err(PgError::QueryServer(err.into()));
73                }
74                _ => {}
75            }
76        }
77    }
78
79    /// **Fastest** bulk insert using COPY protocol with pre-encoded data.
80    /// Accepts raw COPY text format bytes, no encoding needed.
81    /// Use when caller has already encoded rows to COPY format.
82    /// # Format
83    /// Data should be tab-separated rows with newlines:
84    /// `1\thello\t3.14\n2\tworld\t2.71\n`
85    pub async fn copy_in_raw(
86        &mut self,
87        table: &str,
88        columns: &[String],
89        data: &[u8],
90    ) -> PgResult<u64> {
91        let cols: Vec<String> = columns.iter().map(|c| quote_ident(c)).collect();
92        let sql = format!(
93            "COPY {} ({}) FROM STDIN",
94            quote_ident(table),
95            cols.join(", ")
96        );
97
98        // Send COPY command
99        let bytes = PgEncoder::encode_query_string(&sql);
100        self.stream.write_all(&bytes).await?;
101
102        // Wait for CopyInResponse
103        loop {
104            let msg = self.recv().await?;
105            match msg {
106                BackendMessage::CopyInResponse { .. } => break,
107                BackendMessage::ErrorResponse(err) => {
108                    return Err(PgError::QueryServer(err.into()));
109                }
110                _ => {}
111            }
112        }
113
114        // Single write - data is already encoded!
115        self.send_copy_data(data).await?;
116
117        // Send CopyDone
118        self.send_copy_done().await?;
119
120        // Wait for CommandComplete
121        let mut affected = 0u64;
122        loop {
123            let msg = self.recv().await?;
124            match msg {
125                BackendMessage::CommandComplete(tag) => {
126                    affected = parse_affected_rows(&tag);
127                }
128                BackendMessage::ReadyForQuery(_) => {
129                    return Ok(affected);
130                }
131                BackendMessage::ErrorResponse(err) => {
132                    return Err(PgError::QueryServer(err.into()));
133                }
134                _ => {}
135            }
136        }
137    }
138
139    /// Send CopyData message (raw bytes).
140    async fn send_copy_data(&mut self, data: &[u8]) -> PgResult<()> {
141        // CopyData: 'd' + length + data
142        let len = (data.len() + 4) as i32;
143        let mut buf = BytesMut::with_capacity(1 + 4 + data.len());
144        buf.extend_from_slice(b"d");
145        buf.extend_from_slice(&len.to_be_bytes());
146        buf.extend_from_slice(data);
147        self.stream.write_all(&buf).await?;
148        Ok(())
149    }
150
151    async fn send_copy_done(&mut self) -> PgResult<()> {
152        // CopyDone: 'c' + length (4)
153        self.stream.write_all(&[b'c', 0, 0, 0, 4]).await?;
154        Ok(())
155    }
156
157    /// Export data using COPY TO STDOUT (AST-native).
158    /// Takes a `Qail::Export` and returns rows as `Vec<Vec<String>>`.
159    /// # Example
160    /// ```ignore
161    /// let cmd = Qail::export("users")
162    ///     .columns(["id", "name"])
163    ///     .filter("active", true);
164    /// let rows = conn.copy_export(&cmd).await?;
165    /// ```
166    pub async fn copy_export(&mut self, cmd: &Qail) -> PgResult<Vec<Vec<String>>> {
167        if cmd.action != Action::Export {
168            return Err(PgError::Query(
169                "copy_export requires Qail::Export action".to_string(),
170            ));
171        }
172
173        // Encode command to SQL using AST encoder
174        let (sql, _params) =
175            AstEncoder::encode_cmd_sql(cmd).map_err(|e| PgError::Encode(e.to_string()))?;
176
177        // Send COPY command
178        let bytes = PgEncoder::encode_query_string(&sql);
179        self.stream.write_all(&bytes).await?;
180
181        // Wait for CopyOutResponse
182        loop {
183            let msg = self.recv().await?;
184            match msg {
185                BackendMessage::CopyOutResponse { .. } => break,
186                BackendMessage::ErrorResponse(err) => {
187                    return Err(PgError::QueryServer(err.into()));
188                }
189                _ => {}
190            }
191        }
192
193        // Receive CopyData messages until CopyDone
194        let mut rows = Vec::new();
195        loop {
196            let msg = self.recv().await?;
197            match msg {
198                BackendMessage::CopyData(data) => {
199                    let line = String::from_utf8_lossy(&data);
200                    let line = line.trim_end_matches('\n');
201                    let cols: Vec<String> = line.split('\t').map(|s| s.to_string()).collect();
202                    rows.push(cols);
203                }
204                BackendMessage::CopyDone => {}
205                BackendMessage::CommandComplete(_) => {}
206                BackendMessage::ReadyForQuery(_) => {
207                    return Ok(rows);
208                }
209                BackendMessage::ErrorResponse(err) => {
210                    return Err(PgError::QueryServer(err.into()));
211                }
212                _ => {}
213            }
214        }
215    }
216
217    /// Export data using raw COPY TO STDOUT, returning raw bytes.
218    /// Format: tab-separated values, newline-terminated rows.
219    /// Suitable for direct re-import via copy_in_raw.
220    ///
221    /// # Safety
222    /// `pub(crate)` — not exposed externally because callers pass raw SQL.
223    /// External code should use `copy_export()` with the AST encoder instead.
224    pub(crate) async fn copy_out_raw(&mut self, sql: &str) -> PgResult<Vec<u8>> {
225        // Send COPY command
226        let bytes = PgEncoder::encode_query_string(sql);
227        self.stream.write_all(&bytes).await?;
228
229        // Wait for CopyOutResponse
230        loop {
231            let msg = self.recv().await?;
232            match msg {
233                BackendMessage::CopyOutResponse { .. } => break,
234                BackendMessage::ErrorResponse(err) => {
235                    return Err(PgError::QueryServer(err.into()));
236                }
237                _ => {}
238            }
239        }
240
241        // Receive CopyData messages until CopyDone
242        let mut data = Vec::new();
243        loop {
244            let msg = self.recv().await?;
245            match msg {
246                BackendMessage::CopyData(chunk) => {
247                    data.extend_from_slice(&chunk);
248                }
249                BackendMessage::CopyDone => {}
250                BackendMessage::CommandComplete(_) => {}
251                BackendMessage::ReadyForQuery(_) => {
252                    return Ok(data);
253                }
254                BackendMessage::ErrorResponse(err) => {
255                    return Err(PgError::QueryServer(err.into()));
256                }
257                _ => {}
258            }
259        }
260    }
261}